diff --git a/CHANGELOG.md b/CHANGELOG.md
index b628c4ee2070c..539f5a6628dac 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- [Remote Store] Add support to disable flush based on translog reader count ([#14027](https://github.com/opensearch-project/OpenSearch/pull/14027))
- [Query Insights] Add exporter support for top n queries ([#12982](https://github.com/opensearch-project/OpenSearch/pull/12982))
- [Query Insights] Add X-Opaque-Id to search request metadata for top n queries ([#13374](https://github.com/opensearch-project/OpenSearch/pull/13374))
+- Add support for query level resource usage tracking ([#13172](https://github.com/opensearch-project/OpenSearch/pull/13172))
### Dependencies
- Bump `com.github.spullara.mustache.java:compiler` from 0.9.10 to 0.9.13 ([#13329](https://github.com/opensearch-project/OpenSearch/pull/13329), [#13559](https://github.com/opensearch-project/OpenSearch/pull/13559))
diff --git a/libs/core/src/main/java/org/opensearch/core/tasks/resourcetracker/ResourceUsageInfo.java b/libs/core/src/main/java/org/opensearch/core/tasks/resourcetracker/ResourceUsageInfo.java
index a278b61894a65..e7b51c3389b52 100644
--- a/libs/core/src/main/java/org/opensearch/core/tasks/resourcetracker/ResourceUsageInfo.java
+++ b/libs/core/src/main/java/org/opensearch/core/tasks/resourcetracker/ResourceUsageInfo.java
@@ -104,6 +104,10 @@ public long getTotalValue() {
return endValue.get() - startValue;
}
+ public long getStartValue() {
+ return startValue;
+ }
+
@Override
public String toString() {
return String.valueOf(getTotalValue());
diff --git a/libs/core/src/main/java/org/opensearch/core/tasks/resourcetracker/TaskResourceInfo.java b/libs/core/src/main/java/org/opensearch/core/tasks/resourcetracker/TaskResourceInfo.java
new file mode 100644
index 0000000000000..373cdbfa7e9a1
--- /dev/null
+++ b/libs/core/src/main/java/org/opensearch/core/tasks/resourcetracker/TaskResourceInfo.java
@@ -0,0 +1,225 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.core.tasks.resourcetracker;
+
+import org.opensearch.common.annotation.PublicApi;
+import org.opensearch.core.ParseField;
+import org.opensearch.core.common.Strings;
+import org.opensearch.core.common.io.stream.StreamInput;
+import org.opensearch.core.common.io.stream.StreamOutput;
+import org.opensearch.core.common.io.stream.Writeable;
+import org.opensearch.core.xcontent.ConstructingObjectParser;
+import org.opensearch.core.xcontent.MediaTypeRegistry;
+import org.opensearch.core.xcontent.ToXContentObject;
+import org.opensearch.core.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.Objects;
+
+import static org.opensearch.core.xcontent.ConstructingObjectParser.constructorArg;
+
+/**
+ * Task resource usage information with minimal information about the task
+ *
+ * Writeable TaskResourceInfo objects are used to represent resource usage
+ * information of running tasks, which can be propagated to coordinator node
+ * to infer query-level resource usage
+ *
+ * @opensearch.api
+ */
+@PublicApi(since = "2.15.0")
+public class TaskResourceInfo implements Writeable, ToXContentObject {
+ private final String action;
+ private final long taskId;
+ private final long parentTaskId;
+ private final String nodeId;
+ private final TaskResourceUsage taskResourceUsage;
+
+ private static final ParseField ACTION = new ParseField("action");
+ private static final ParseField TASK_ID = new ParseField("taskId");
+ private static final ParseField PARENT_TASK_ID = new ParseField("parentTaskId");
+ private static final ParseField NODE_ID = new ParseField("nodeId");
+ private static final ParseField TASK_RESOURCE_USAGE = new ParseField("taskResourceUsage");
+
+ public TaskResourceInfo(
+ final String action,
+ final long taskId,
+ final long parentTaskId,
+ final String nodeId,
+ final TaskResourceUsage taskResourceUsage
+ ) {
+ this.action = action;
+ this.taskId = taskId;
+ this.parentTaskId = parentTaskId;
+ this.nodeId = nodeId;
+ this.taskResourceUsage = taskResourceUsage;
+ }
+
+ public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(
+ "task_resource_info",
+ a -> new Builder().setAction((String) a[0])
+ .setTaskId((Long) a[1])
+ .setParentTaskId((Long) a[2])
+ .setNodeId((String) a[3])
+ .setTaskResourceUsage((TaskResourceUsage) a[4])
+ .build()
+ );
+
+ static {
+ PARSER.declareString(constructorArg(), ACTION);
+ PARSER.declareLong(constructorArg(), TASK_ID);
+ PARSER.declareLong(constructorArg(), PARENT_TASK_ID);
+ PARSER.declareString(constructorArg(), NODE_ID);
+ PARSER.declareObject(constructorArg(), TaskResourceUsage.PARSER, TASK_RESOURCE_USAGE);
+ }
+
+ @Override
+ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+ builder.startObject();
+ builder.field(ACTION.getPreferredName(), this.action);
+ builder.field(TASK_ID.getPreferredName(), this.taskId);
+ builder.field(PARENT_TASK_ID.getPreferredName(), this.parentTaskId);
+ builder.field(NODE_ID.getPreferredName(), this.nodeId);
+ builder.startObject(TASK_RESOURCE_USAGE.getPreferredName());
+ this.taskResourceUsage.toXContent(builder, params);
+ builder.endObject();
+ builder.endObject();
+ return builder;
+ }
+
+ /**
+ * Builder for {@link TaskResourceInfo}
+ */
+ public static class Builder {
+ private TaskResourceUsage taskResourceUsage;
+ private String action;
+ private long taskId;
+ private long parentTaskId;
+ private String nodeId;
+
+ public Builder setTaskResourceUsage(final TaskResourceUsage taskResourceUsage) {
+ this.taskResourceUsage = taskResourceUsage;
+ return this;
+ }
+
+ public Builder setAction(final String action) {
+ this.action = action;
+ return this;
+ }
+
+ public Builder setTaskId(final long taskId) {
+ this.taskId = taskId;
+ return this;
+ }
+
+ public Builder setParentTaskId(final long parentTaskId) {
+ this.parentTaskId = parentTaskId;
+ return this;
+ }
+
+ public Builder setNodeId(final String nodeId) {
+ this.nodeId = nodeId;
+ return this;
+ }
+
+ public TaskResourceInfo build() {
+ return new TaskResourceInfo(action, taskId, parentTaskId, nodeId, taskResourceUsage);
+ }
+ }
+
+ /**
+ * Read task info from a stream.
+ *
+ * @param in StreamInput to read
+ * @return {@link TaskResourceInfo}
+ * @throws IOException IOException
+ */
+ public static TaskResourceInfo readFromStream(StreamInput in) throws IOException {
+ return new TaskResourceInfo.Builder().setAction(in.readString())
+ .setTaskId(in.readLong())
+ .setParentTaskId(in.readLong())
+ .setNodeId(in.readString())
+ .setTaskResourceUsage(TaskResourceUsage.readFromStream(in))
+ .build();
+ }
+
+ /**
+ * Get TaskResourceUsage
+ *
+ * @return taskResourceUsage
+ */
+ public TaskResourceUsage getTaskResourceUsage() {
+ return taskResourceUsage;
+ }
+
+ /**
+ * Get parent task id
+ *
+ * @return parent task id
+ */
+ public long getParentTaskId() {
+ return parentTaskId;
+ }
+
+ /**
+ * Get task id
+ * @return task id
+ */
+ public long getTaskId() {
+ return taskId;
+ }
+
+ /**
+ * Get node id
+ * @return node id
+ */
+ public String getNodeId() {
+ return nodeId;
+ }
+
+ /**
+ * Get task action
+ * @return task action
+ */
+ public String getAction() {
+ return action;
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ out.writeString(action);
+ out.writeLong(taskId);
+ out.writeLong(parentTaskId);
+ out.writeString(nodeId);
+ taskResourceUsage.writeTo(out);
+ }
+
+ @Override
+ public String toString() {
+ return Strings.toString(MediaTypeRegistry.JSON, this);
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (obj == null || obj.getClass() != TaskResourceInfo.class) {
+ return false;
+ }
+ TaskResourceInfo other = (TaskResourceInfo) obj;
+ return action.equals(other.action)
+ && taskId == other.taskId
+ && parentTaskId == other.parentTaskId
+ && Objects.equals(nodeId, other.nodeId)
+ && taskResourceUsage.equals(other.taskResourceUsage);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(action, taskId, parentTaskId, nodeId, taskResourceUsage);
+ }
+}
diff --git a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java
index 9bf4a4b1e18f1..f0fc05c595d6f 100644
--- a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java
+++ b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java
@@ -51,6 +51,7 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.ShardOperationFailedException;
import org.opensearch.core.index.shard.ShardId;
+import org.opensearch.core.tasks.resourcetracker.TaskResourceInfo;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.internal.AliasFilter;
@@ -469,6 +470,10 @@ private void onRequestEnd(SearchRequestContext searchRequestContext) {
this.searchRequestContext.getSearchRequestOperationsListener().onRequestEnd(this, searchRequestContext);
}
+ private void onRequestFailure(SearchRequestContext searchRequestContext) {
+ this.searchRequestContext.getSearchRequestOperationsListener().onRequestFailure(this, searchRequestContext);
+ }
+
private void executePhase(SearchPhase phase) {
Span phaseSpan = tracer.startSpan(SpanCreationContext.server().name("[phase/" + phase.getName() + "]"));
try (final SpanScope scope = tracer.withSpanInScope(phaseSpan)) {
@@ -507,6 +512,7 @@ ShardSearchFailure[] buildShardFailures() {
private void onShardFailure(final int shardIndex, @Nullable SearchShardTarget shard, final SearchShardIterator shardIt, Exception e) {
// we always add the shard failure for a specific shard instance
// we do make sure to clean it on a successful response from a shard
+ setPhaseResourceUsages();
onShardFailure(shardIndex, shard, e);
SearchShardTarget nextShard = FailAwareWeightedRouting.getInstance()
.findNext(shardIt, clusterState, e, () -> totalOps.incrementAndGet());
@@ -618,9 +624,15 @@ protected void onShardResult(Result result, SearchShardIterator shardIt) {
if (logger.isTraceEnabled()) {
logger.trace("got first-phase result from {}", result != null ? result.getSearchShardTarget() : null);
}
+ this.setPhaseResourceUsages();
results.consumeResult(result, () -> onShardResultConsumed(result, shardIt));
}
+ public void setPhaseResourceUsages() {
+ TaskResourceInfo taskResourceUsage = searchRequestContext.getTaskResourceUsageSupplier().get();
+ searchRequestContext.recordPhaseResourceUsage(taskResourceUsage);
+ }
+
private void onShardResultConsumed(Result result, SearchShardIterator shardIt) {
successfulOps.incrementAndGet();
// clean a previous error on this shard group (note, this code will be serialized on the same shardIndex value level
@@ -751,6 +763,7 @@ public void sendSearchResponse(InternalSearchResponse internalSearchResponse, At
@Override
public final void onPhaseFailure(SearchPhase phase, String msg, Throwable cause) {
+ setPhaseResourceUsages();
if (currentPhaseHasLifecycle) {
this.searchRequestContext.getSearchRequestOperationsListener().onPhaseFailure(this, cause);
}
@@ -780,6 +793,7 @@ private void raisePhaseFailure(SearchPhaseExecutionException exception) {
});
}
Releasables.close(releasables);
+ onRequestFailure(searchRequestContext);
listener.onFailure(exception);
}
diff --git a/server/src/main/java/org/opensearch/action/search/FetchSearchPhase.java b/server/src/main/java/org/opensearch/action/search/FetchSearchPhase.java
index ebb2f33f8f37d..2ad7f8a29896c 100644
--- a/server/src/main/java/org/opensearch/action/search/FetchSearchPhase.java
+++ b/server/src/main/java/org/opensearch/action/search/FetchSearchPhase.java
@@ -240,6 +240,7 @@ private void executeFetch(
public void innerOnResponse(FetchSearchResult result) {
try {
progressListener.notifyFetchResult(shardIndex);
+ context.setPhaseResourceUsages();
counter.onResult(result);
} catch (Exception e) {
context.onPhaseFailure(FetchSearchPhase.this, "", e);
@@ -254,6 +255,7 @@ public void onFailure(Exception e) {
e
);
progressListener.notifyFetchFailure(shardIndex, shardTarget, e);
+ context.setPhaseResourceUsages();
counter.onFailure(shardIndex, shardTarget, e);
} finally {
// the search context might not be cleared on the node where the fetch was executed for example
diff --git a/server/src/main/java/org/opensearch/action/search/SearchPhaseContext.java b/server/src/main/java/org/opensearch/action/search/SearchPhaseContext.java
index df451e0745e3c..55f2a22749e70 100644
--- a/server/src/main/java/org/opensearch/action/search/SearchPhaseContext.java
+++ b/server/src/main/java/org/opensearch/action/search/SearchPhaseContext.java
@@ -150,4 +150,9 @@ default void sendReleaseSearchContext(
* Registers a {@link Releasable} that will be closed when the search request finishes or fails.
*/
void addReleasable(Releasable releasable);
+
+ /**
+ * Set the resource usage info for this phase
+ */
+ void setPhaseResourceUsages();
}
diff --git a/server/src/main/java/org/opensearch/action/search/SearchRequestContext.java b/server/src/main/java/org/opensearch/action/search/SearchRequestContext.java
index 5b133ba0554f4..111d9c64550b3 100644
--- a/server/src/main/java/org/opensearch/action/search/SearchRequestContext.java
+++ b/server/src/main/java/org/opensearch/action/search/SearchRequestContext.java
@@ -8,13 +8,20 @@
package org.opensearch.action.search;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.TotalHits;
import org.opensearch.common.annotation.InternalApi;
+import org.opensearch.core.tasks.resourcetracker.TaskResourceInfo;
+import java.util.ArrayList;
import java.util.EnumMap;
import java.util.HashMap;
+import java.util.List;
import java.util.Locale;
import java.util.Map;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.function.Supplier;
/**
* This class holds request-level context for search queries at the coordinator node
@@ -23,6 +30,7 @@
*/
@InternalApi
public class SearchRequestContext {
+ private static final Logger logger = LogManager.getLogger();
private final SearchRequestOperationsListener searchRequestOperationsListener;
private long absoluteStartNanos;
private final Map phaseTookMap;
@@ -30,13 +38,21 @@ public class SearchRequestContext {
private final EnumMap shardStats;
private final SearchRequest searchRequest;
-
- SearchRequestContext(final SearchRequestOperationsListener searchRequestOperationsListener, final SearchRequest searchRequest) {
+ private final LinkedBlockingQueue phaseResourceUsage;
+ private final Supplier taskResourceUsageSupplier;
+
+ SearchRequestContext(
+ final SearchRequestOperationsListener searchRequestOperationsListener,
+ final SearchRequest searchRequest,
+ final Supplier taskResourceUsageSupplier
+ ) {
this.searchRequestOperationsListener = searchRequestOperationsListener;
this.absoluteStartNanos = System.nanoTime();
this.phaseTookMap = new HashMap<>();
this.shardStats = new EnumMap<>(ShardStatsFieldNames.class);
this.searchRequest = searchRequest;
+ this.phaseResourceUsage = new LinkedBlockingQueue<>();
+ this.taskResourceUsageSupplier = taskResourceUsageSupplier;
}
SearchRequestOperationsListener getSearchRequestOperationsListener() {
@@ -108,6 +124,20 @@ String formattedShardStats() {
}
}
+ public Supplier getTaskResourceUsageSupplier() {
+ return taskResourceUsageSupplier;
+ }
+
+ public void recordPhaseResourceUsage(TaskResourceInfo usage) {
+ if (usage != null) {
+ this.phaseResourceUsage.add(usage);
+ }
+ }
+
+ public List getPhaseResourceUsage() {
+ return new ArrayList<>(phaseResourceUsage);
+ }
+
public SearchRequest getRequest() {
return searchRequest;
}
diff --git a/server/src/main/java/org/opensearch/action/search/SearchRequestOperationsListener.java b/server/src/main/java/org/opensearch/action/search/SearchRequestOperationsListener.java
index b944572cef122..61f19977ae5ce 100644
--- a/server/src/main/java/org/opensearch/action/search/SearchRequestOperationsListener.java
+++ b/server/src/main/java/org/opensearch/action/search/SearchRequestOperationsListener.java
@@ -51,6 +51,8 @@ protected void onRequestStart(SearchRequestContext searchRequestContext) {}
protected void onRequestEnd(SearchPhaseContext context, SearchRequestContext searchRequestContext) {}
+ protected void onRequestFailure(SearchPhaseContext context, SearchRequestContext searchRequestContext) {}
+
protected boolean isEnabled(SearchRequest searchRequest) {
return isEnabled();
}
@@ -133,6 +135,17 @@ public void onRequestEnd(SearchPhaseContext context, SearchRequestContext search
}
}
+ @Override
+ public void onRequestFailure(SearchPhaseContext context, SearchRequestContext searchRequestContext) {
+ for (SearchRequestOperationsListener listener : listeners) {
+ try {
+ listener.onRequestFailure(context, searchRequestContext);
+ } catch (Exception e) {
+ logger.warn(() -> new ParameterizedMessage("onRequestFailure listener [{}] failed", listener), e);
+ }
+ }
+ }
+
public List getListeners() {
return listeners;
}
diff --git a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java
index 143b01af3f62f..6e380775355a2 100644
--- a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java
+++ b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java
@@ -87,6 +87,7 @@
import org.opensearch.search.profile.SearchProfileShardResults;
import org.opensearch.tasks.CancellableTask;
import org.opensearch.tasks.Task;
+import org.opensearch.tasks.TaskResourceTrackingService;
import org.opensearch.telemetry.metrics.MetricsRegistry;
import org.opensearch.telemetry.tracing.Span;
import org.opensearch.telemetry.tracing.SpanBuilder;
@@ -186,6 +187,7 @@ public class TransportSearchAction extends HandledTransportAction) SearchRequest::new);
this.client = client;
@@ -224,6 +227,7 @@ public TransportSearchAction(
clusterService.getClusterSettings()
.addSettingsUpdateConsumer(SEARCH_QUERY_METRICS_ENABLED_SETTING, this::setSearchQueryMetricsEnabled);
this.tracer = tracer;
+ this.taskResourceTrackingService = taskResourceTrackingService;
}
private void setSearchQueryMetricsEnabled(boolean searchQueryMetricsEnabled) {
@@ -451,7 +455,11 @@ private void executeRequest(
logger,
TraceableSearchRequestOperationsListener.create(tracer, requestSpan)
);
- SearchRequestContext searchRequestContext = new SearchRequestContext(requestOperationsListeners, originalSearchRequest);
+ SearchRequestContext searchRequestContext = new SearchRequestContext(
+ requestOperationsListeners,
+ originalSearchRequest,
+ taskResourceTrackingService::getTaskResourceUsageFromThreadContext
+ );
searchRequestContext.getSearchRequestOperationsListener().onRequestStart(searchRequestContext);
PipelinedRequest searchRequest;
diff --git a/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java b/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java
index 6580b0e0085ef..0b1aa9a4a759a 100644
--- a/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java
+++ b/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java
@@ -483,6 +483,15 @@ public void addResponseHeader(final String key, final String value) {
addResponseHeader(key, value, v -> v);
}
+ /**
+ * Remove the {@code value} for the specified {@code key}.
+ *
+ * @param key the header name
+ */
+ public void removeResponseHeader(final String key) {
+ threadLocal.get().responseHeaders.remove(key);
+ }
+
/**
* Add the {@code value} for the specified {@code key} with the specified {@code uniqueValue} used for de-duplication. Any duplicate
* {@code value} after applying {@code uniqueValue} is ignored.
diff --git a/server/src/main/java/org/opensearch/node/Node.java b/server/src/main/java/org/opensearch/node/Node.java
index cb1f2caa082fc..f7a901335f34a 100644
--- a/server/src/main/java/org/opensearch/node/Node.java
+++ b/server/src/main/java/org/opensearch/node/Node.java
@@ -1261,7 +1261,8 @@ protected Node(
searchModule.getFetchPhase(),
responseCollectorService,
circuitBreakerService,
- searchModule.getIndexSearcherExecutor(threadPool)
+ searchModule.getIndexSearcherExecutor(threadPool),
+ taskResourceTrackingService
);
final List> tasksExecutors = pluginsService.filterPlugins(PersistentTaskPlugin.class)
@@ -1905,7 +1906,8 @@ protected SearchService newSearchService(
FetchPhase fetchPhase,
ResponseCollectorService responseCollectorService,
CircuitBreakerService circuitBreakerService,
- Executor indexSearcherExecutor
+ Executor indexSearcherExecutor,
+ TaskResourceTrackingService taskResourceTrackingService
) {
return new SearchService(
clusterService,
@@ -1917,7 +1919,8 @@ protected SearchService newSearchService(
fetchPhase,
responseCollectorService,
circuitBreakerService,
- indexSearcherExecutor
+ indexSearcherExecutor,
+ taskResourceTrackingService
);
}
diff --git a/server/src/main/java/org/opensearch/search/SearchService.java b/server/src/main/java/org/opensearch/search/SearchService.java
index d371d69a57804..45f111d889522 100644
--- a/server/src/main/java/org/opensearch/search/SearchService.java
+++ b/server/src/main/java/org/opensearch/search/SearchService.java
@@ -137,6 +137,7 @@
import org.opensearch.search.sort.SortOrder;
import org.opensearch.search.suggest.Suggest;
import org.opensearch.search.suggest.completion.CompletionSuggestion;
+import org.opensearch.tasks.TaskResourceTrackingService;
import org.opensearch.threadpool.Scheduler.Cancellable;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.threadpool.ThreadPool.Names;
@@ -338,6 +339,7 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv
private final AtomicInteger openPitContexts = new AtomicInteger();
private final String sessionId = UUIDs.randomBase64UUID();
private final Executor indexSearcherExecutor;
+ private final TaskResourceTrackingService taskResourceTrackingService;
public SearchService(
ClusterService clusterService,
@@ -349,7 +351,8 @@ public SearchService(
FetchPhase fetchPhase,
ResponseCollectorService responseCollectorService,
CircuitBreakerService circuitBreakerService,
- Executor indexSearcherExecutor
+ Executor indexSearcherExecutor,
+ TaskResourceTrackingService taskResourceTrackingService
) {
Settings settings = clusterService.getSettings();
this.threadPool = threadPool;
@@ -366,6 +369,7 @@ public SearchService(
circuitBreakerService.getBreaker(CircuitBreaker.REQUEST)
);
this.indexSearcherExecutor = indexSearcherExecutor;
+ this.taskResourceTrackingService = taskResourceTrackingService;
TimeValue keepAliveInterval = KEEPALIVE_INTERVAL_SETTING.get(settings);
setKeepAlives(DEFAULT_KEEPALIVE_SETTING.get(settings), MAX_KEEPALIVE_SETTING.get(settings));
setPitKeepAlives(DEFAULT_KEEPALIVE_SETTING.get(settings), MAX_PIT_KEEPALIVE_SETTING.get(settings));
@@ -558,6 +562,8 @@ private DfsSearchResult executeDfsPhase(ShardSearchRequest request, SearchShardT
logger.trace("Dfs phase failed", e);
processFailure(readerContext, e);
throw e;
+ } finally {
+ taskResourceTrackingService.writeTaskResourceUsage(task, clusterService.localNode().getId());
}
}
@@ -660,6 +666,8 @@ private SearchPhaseResult executeQueryPhase(ShardSearchRequest request, SearchSh
logger.trace("Query phase failed", e);
processFailure(readerContext, e);
throw e;
+ } finally {
+ taskResourceTrackingService.writeTaskResourceUsage(task, clusterService.localNode().getId());
}
}
@@ -705,6 +713,8 @@ public void executeQueryPhase(
logger.trace("Query phase failed", e);
// we handle the failure in the failure listener below
throw e;
+ } finally {
+ taskResourceTrackingService.writeTaskResourceUsage(task, clusterService.localNode().getId());
}
}, wrapFailureListener(listener, readerContext, markAsUsed));
}
@@ -737,6 +747,8 @@ public void executeQueryPhase(QuerySearchRequest request, SearchShardTask task,
logger.trace("Query phase failed", e);
// we handle the failure in the failure listener below
throw e;
+ } finally {
+ taskResourceTrackingService.writeTaskResourceUsage(task, clusterService.localNode().getId());
}
}, wrapFailureListener(listener, readerContext, markAsUsed));
}
@@ -786,6 +798,8 @@ public void executeFetchPhase(
logger.trace("Fetch phase failed", e);
// we handle the failure in the failure listener below
throw e;
+ } finally {
+ taskResourceTrackingService.writeTaskResourceUsage(task, clusterService.localNode().getId());
}
}, wrapFailureListener(listener, readerContext, markAsUsed));
}
@@ -816,6 +830,8 @@ public void executeFetchPhase(ShardFetchRequest request, SearchShardTask task, A
assert TransportActions.isShardNotAvailableException(e) == false : new AssertionError(e);
// we handle the failure in the failure listener below
throw e;
+ } finally {
+ taskResourceTrackingService.writeTaskResourceUsage(task, clusterService.localNode().getId());
}
}, wrapFailureListener(listener, readerContext, markAsUsed));
}
@@ -1749,6 +1765,7 @@ public CanMatchResponse(boolean canMatch, MinAndMax> estimatedMinAndMax) {
@Override
public void writeTo(StreamOutput out) throws IOException {
+ super.writeTo(out);
out.writeBoolean(canMatch);
out.writeOptionalWriteable(estimatedMinAndMax);
}
diff --git a/server/src/main/java/org/opensearch/tasks/Task.java b/server/src/main/java/org/opensearch/tasks/Task.java
index a21a454a65d0e..0fa65bc16516f 100644
--- a/server/src/main/java/org/opensearch/tasks/Task.java
+++ b/server/src/main/java/org/opensearch/tasks/Task.java
@@ -476,6 +476,18 @@ public void stopThreadResourceTracking(long threadId, ResourceStatsType statsTyp
throw new IllegalStateException("cannot update final values if active thread resource entry is not present");
}
+ public ThreadResourceInfo getActiveThreadResourceInfo(long threadId, ResourceStatsType statsType) {
+ final List threadResourceInfoList = resourceStats.get(threadId);
+ if (threadResourceInfoList != null) {
+ for (ThreadResourceInfo threadResourceInfo : threadResourceInfoList) {
+ if (threadResourceInfo.getStatsType() == statsType && threadResourceInfo.isActive()) {
+ return threadResourceInfo;
+ }
+ }
+ }
+ return null;
+ }
+
/**
* Individual tasks can override this if they want to support task resource tracking. We just need to make sure that
* the ThreadPool on which the task runs on have runnable wrapper similar to
diff --git a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java
index f32559f6314c0..564eff6c10df6 100644
--- a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java
+++ b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java
@@ -14,6 +14,7 @@
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.opensearch.ExceptionsHelper;
+import org.opensearch.action.search.SearchShardTask;
import org.opensearch.common.SuppressForbidden;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.ClusterSettings;
@@ -22,12 +23,23 @@
import org.opensearch.common.util.concurrent.ConcurrentCollections;
import org.opensearch.common.util.concurrent.ConcurrentMapLong;
import org.opensearch.common.util.concurrent.ThreadContext;
+import org.opensearch.common.xcontent.XContentHelper;
+import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.tasks.resourcetracker.ResourceStats;
+import org.opensearch.core.tasks.resourcetracker.ResourceStatsType;
+import org.opensearch.core.tasks.resourcetracker.ResourceUsageInfo;
import org.opensearch.core.tasks.resourcetracker.ResourceUsageMetric;
+import org.opensearch.core.tasks.resourcetracker.TaskResourceInfo;
+import org.opensearch.core.tasks.resourcetracker.TaskResourceUsage;
import org.opensearch.core.tasks.resourcetracker.ThreadResourceInfo;
+import org.opensearch.core.xcontent.DeprecationHandler;
+import org.opensearch.core.xcontent.MediaTypeRegistry;
+import org.opensearch.core.xcontent.NamedXContentRegistry;
+import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.threadpool.RunnableTaskExecutionListener;
import org.opensearch.threadpool.ThreadPool;
+import java.io.IOException;
import java.lang.management.ManagementFactory;
import java.util.ArrayList;
import java.util.Collections;
@@ -51,6 +63,7 @@ public class TaskResourceTrackingService implements RunnableTaskExecutionListene
Setting.Property.NodeScope
);
public static final String TASK_ID = "TASK_ID";
+ public static final String TASK_RESOURCE_USAGE = "TASK_RESOURCE_USAGE";
private static final ThreadMXBean threadMXBean = (ThreadMXBean) ManagementFactory.getThreadMXBean();
@@ -261,6 +274,89 @@ private ThreadContext.StoredContext addTaskIdToThreadContext(Task task) {
return storedContext;
}
+ /**
+ * Get the current task level resource usage.
+ *
+ * @param task {@link SearchShardTask}
+ * @param nodeId the local nodeId
+ */
+ public void writeTaskResourceUsage(SearchShardTask task, String nodeId) {
+ try {
+ // Get resource usages from when the task started
+ ThreadResourceInfo threadResourceInfo = task.getActiveThreadResourceInfo(
+ Thread.currentThread().getId(),
+ ResourceStatsType.WORKER_STATS
+ );
+ if (threadResourceInfo == null) {
+ return;
+ }
+ Map startValues = threadResourceInfo.getResourceUsageInfo().getStatsInfo();
+ if (!(startValues.containsKey(ResourceStats.CPU) && startValues.containsKey(ResourceStats.MEMORY))) {
+ return;
+ }
+ // Get current resource usages
+ ResourceUsageMetric[] endValues = getResourceUsageMetricsForThread(Thread.currentThread().getId());
+ long cpu = -1, mem = -1;
+ for (ResourceUsageMetric endValue : endValues) {
+ if (endValue.getStats() == ResourceStats.MEMORY) {
+ mem = endValue.getValue();
+ } else if (endValue.getStats() == ResourceStats.CPU) {
+ cpu = endValue.getValue();
+ }
+ }
+ if (cpu == -1 || mem == -1) {
+ logger.debug("Invalid resource usage value, cpu [{}], memory [{}]: ", cpu, mem);
+ return;
+ }
+
+ // Build task resource usage info
+ TaskResourceInfo taskResourceInfo = new TaskResourceInfo.Builder().setAction(task.getAction())
+ .setTaskId(task.getId())
+ .setParentTaskId(task.getParentTaskId().getId())
+ .setNodeId(nodeId)
+ .setTaskResourceUsage(
+ new TaskResourceUsage(
+ cpu - startValues.get(ResourceStats.CPU).getStartValue(),
+ mem - startValues.get(ResourceStats.MEMORY).getStartValue()
+ )
+ )
+ .build();
+ // Remove the existing TASK_RESOURCE_USAGE header since it would have come from an earlier phase in the same request.
+ synchronized (this) {
+ threadPool.getThreadContext().removeResponseHeader(TASK_RESOURCE_USAGE);
+ threadPool.getThreadContext().addResponseHeader(TASK_RESOURCE_USAGE, taskResourceInfo.toString());
+ }
+ } catch (Exception e) {
+ logger.debug("Error during writing task resource usage: ", e);
+ }
+ }
+
+ /**
+ * Get the task resource usages from {@link ThreadContext}
+ *
+ * @return {@link TaskResourceInfo}
+ */
+ public TaskResourceInfo getTaskResourceUsageFromThreadContext() {
+ List taskResourceUsages = threadPool.getThreadContext().getResponseHeaders().get(TASK_RESOURCE_USAGE);
+ if (taskResourceUsages != null && taskResourceUsages.size() > 0) {
+ String usage = taskResourceUsages.get(0);
+ try {
+ if (usage != null && !usage.isEmpty()) {
+ XContentParser parser = XContentHelper.createParser(
+ NamedXContentRegistry.EMPTY,
+ DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
+ new BytesArray(usage),
+ MediaTypeRegistry.JSON
+ );
+ return TaskResourceInfo.PARSER.apply(parser, null);
+ }
+ } catch (IOException e) {
+ logger.debug("fail to parse phase resource usages: ", e);
+ }
+ }
+ return null;
+ }
+
/**
* Listener that gets invoked when a task execution completes.
*/
diff --git a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java
index 7dcbf213d6c9d..27336e86e52b0 100644
--- a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java
+++ b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java
@@ -49,6 +49,8 @@
import org.opensearch.core.common.breaker.NoopCircuitBreaker;
import org.opensearch.core.index.Index;
import org.opensearch.core.index.shard.ShardId;
+import org.opensearch.core.tasks.resourcetracker.TaskResourceInfo;
+import org.opensearch.core.tasks.resourcetracker.TaskResourceUsage;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.shard.ShardNotFoundException;
import org.opensearch.search.SearchPhaseResult;
@@ -87,6 +89,7 @@
import java.util.function.BiFunction;
import java.util.stream.IntStream;
+import static org.opensearch.tasks.TaskResourceTrackingService.TASK_RESOURCE_USAGE;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.instanceOf;
@@ -123,7 +126,8 @@ private AbstractSearchAsyncAction createAction(
ArraySearchPhaseResults results,
ActionListener listener,
final boolean controlled,
- final AtomicLong expected
+ final AtomicLong expected,
+ final TaskResourceUsage resourceUsage
) {
return createAction(
request,
@@ -133,6 +137,7 @@ private AbstractSearchAsyncAction createAction(
false,
false,
expected,
+ resourceUsage,
new SearchShardIterator(null, null, Collections.emptyList(), null)
);
}
@@ -145,6 +150,7 @@ private AbstractSearchAsyncAction createAction(
final boolean failExecutePhaseOnShard,
final boolean catchExceptionWhenExecutePhaseOnShard,
final AtomicLong expected,
+ final TaskResourceUsage resourceUsage,
final SearchShardIterator... shards
) {
@@ -166,6 +172,14 @@ private AbstractSearchAsyncAction createAction(
return null;
};
+ TaskResourceInfo taskResourceInfo = new TaskResourceInfo.Builder().setTaskResourceUsage(resourceUsage)
+ .setTaskId(randomLong())
+ .setParentTaskId(randomLong())
+ .setAction(randomAlphaOfLengthBetween(1, 5))
+ .setNodeId(randomAlphaOfLengthBetween(1, 5))
+ .build();
+ threadPool.getThreadContext().addResponseHeader(TASK_RESOURCE_USAGE, taskResourceInfo.toString());
+
return new AbstractSearchAsyncAction(
"test",
logger,
@@ -186,7 +200,8 @@ private AbstractSearchAsyncAction createAction(
SearchResponse.Clusters.EMPTY,
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()),
- request
+ request,
+ () -> null
),
NoopTracer.INSTANCE
) {
@@ -248,7 +263,8 @@ private void runTestTook(final boolean controlled) {
new ArraySearchPhaseResults<>(10),
null,
controlled,
- expected
+ expected,
+ new TaskResourceUsage(0, 0)
);
final long actual = action.buildTookInMillis();
if (controlled) {
@@ -268,7 +284,8 @@ public void testBuildShardSearchTransportRequest() {
new ArraySearchPhaseResults<>(10),
null,
false,
- expected
+ expected,
+ new TaskResourceUsage(randomLong(), randomLong())
);
String clusterAlias = randomBoolean() ? null : randomAlphaOfLengthBetween(5, 10);
SearchShardIterator iterator = new SearchShardIterator(
@@ -291,19 +308,39 @@ public void testBuildShardSearchTransportRequest() {
public void testBuildSearchResponse() {
SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(randomBoolean());
ArraySearchPhaseResults phaseResults = new ArraySearchPhaseResults<>(10);
- AbstractSearchAsyncAction action = createAction(searchRequest, phaseResults, null, false, new AtomicLong());
+ TaskResourceUsage taskResourceUsage = new TaskResourceUsage(randomLong(), randomLong());
+ AbstractSearchAsyncAction action = createAction(
+ searchRequest,
+ phaseResults,
+ null,
+ false,
+ new AtomicLong(),
+ taskResourceUsage
+ );
InternalSearchResponse internalSearchResponse = InternalSearchResponse.empty();
SearchResponse searchResponse = action.buildSearchResponse(internalSearchResponse, action.buildShardFailures(), null, null);
assertSame(searchResponse.getAggregations(), internalSearchResponse.aggregations());
assertSame(searchResponse.getSuggest(), internalSearchResponse.suggest());
assertSame(searchResponse.getProfileResults(), internalSearchResponse.profile());
assertSame(searchResponse.getHits(), internalSearchResponse.hits());
+ List resourceUsages = threadPool.getThreadContext().getResponseHeaders().get(TASK_RESOURCE_USAGE);
+ assertNotNull(resourceUsages);
+ assertEquals(1, resourceUsages.size());
+ assertTrue(resourceUsages.get(0).contains(Long.toString(taskResourceUsage.getCpuTimeInNanos())));
+ assertTrue(resourceUsages.get(0).contains(Long.toString(taskResourceUsage.getMemoryInBytes())));
}
public void testBuildSearchResponseAllowPartialFailures() {
SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true);
final ArraySearchPhaseResults queryResult = new ArraySearchPhaseResults<>(10);
- AbstractSearchAsyncAction action = createAction(searchRequest, queryResult, null, false, new AtomicLong());
+ AbstractSearchAsyncAction action = createAction(
+ searchRequest,
+ queryResult,
+ null,
+ false,
+ new AtomicLong(),
+ new TaskResourceUsage(randomLong(), randomLong())
+ );
action.onShardFailure(
0,
new SearchShardTarget("node", new ShardId("index", "index-uuid", 0), null, OriginalIndices.NONE),
@@ -325,7 +362,14 @@ public void testSendSearchResponseDisallowPartialFailures() {
List> nodeLookups = new ArrayList<>();
int numFailures = randomIntBetween(1, 5);
ArraySearchPhaseResults phaseResults = phaseResults(requestIds, nodeLookups, numFailures);
- AbstractSearchAsyncAction action = createAction(searchRequest, phaseResults, listener, false, new AtomicLong());
+ AbstractSearchAsyncAction action = createAction(
+ searchRequest,
+ phaseResults,
+ listener,
+ false,
+ new AtomicLong(),
+ new TaskResourceUsage(randomLong(), randomLong())
+ );
for (int i = 0; i < numFailures; i++) {
ShardId failureShardId = new ShardId("index", "index-uuid", i);
String failureClusterAlias = randomBoolean() ? null : randomAlphaOfLengthBetween(5, 10);
@@ -404,7 +448,14 @@ public void testOnPhaseFailure() {
Set requestIds = new HashSet<>();
List> nodeLookups = new ArrayList<>();
ArraySearchPhaseResults phaseResults = phaseResults(requestIds, nodeLookups, 0);
- AbstractSearchAsyncAction action = createAction(searchRequest, phaseResults, listener, false, new AtomicLong());
+ AbstractSearchAsyncAction action = createAction(
+ searchRequest,
+ phaseResults,
+ listener,
+ false,
+ new AtomicLong(),
+ new TaskResourceUsage(randomLong(), randomLong())
+ );
action.onPhaseFailure(new SearchPhase("test") {
@Override
@@ -428,7 +479,14 @@ public void testShardNotAvailableWithDisallowPartialFailures() {
ActionListener listener = ActionListener.wrap(response -> fail("onResponse should not be called"), exception::set);
int numShards = randomIntBetween(2, 10);
ArraySearchPhaseResults phaseResults = new ArraySearchPhaseResults<>(numShards);
- AbstractSearchAsyncAction action = createAction(searchRequest, phaseResults, listener, false, new AtomicLong());
+ AbstractSearchAsyncAction action = createAction(
+ searchRequest,
+ phaseResults,
+ listener,
+ false,
+ new AtomicLong(),
+ new TaskResourceUsage(randomLong(), randomLong())
+ );
// skip one to avoid the "all shards failed" failure.
SearchShardIterator skipIterator = new SearchShardIterator(null, null, Collections.emptyList(), null);
skipIterator.resetAndSkip();
@@ -450,7 +508,14 @@ public void testShardNotAvailableWithIgnoreUnavailable() {
ActionListener listener = ActionListener.wrap(response -> {}, exception::set);
int numShards = randomIntBetween(2, 10);
ArraySearchPhaseResults phaseResults = new ArraySearchPhaseResults<>(numShards);
- AbstractSearchAsyncAction action = createAction(searchRequest, phaseResults, listener, false, new AtomicLong());
+ AbstractSearchAsyncAction action = createAction(
+ searchRequest,
+ phaseResults,
+ listener,
+ false,
+ new AtomicLong(),
+ new TaskResourceUsage(randomLong(), randomLong())
+ );
// skip one to avoid the "all shards failed" failure.
SearchShardIterator skipIterator = new SearchShardIterator(null, null, Collections.emptyList(), null);
skipIterator.resetAndSkip();
@@ -521,6 +586,7 @@ public void onFailure(Exception e) {
true,
false,
new AtomicLong(),
+ new TaskResourceUsage(randomLong(), randomLong()),
shards
);
action.run();
@@ -568,6 +634,7 @@ public void onFailure(Exception e) {
false,
false,
new AtomicLong(),
+ new TaskResourceUsage(randomLong(), randomLong()),
shards
);
action.run();
@@ -620,6 +687,7 @@ public void onFailure(Exception e) {
false,
catchExceptionWhenExecutePhaseOnShard,
new AtomicLong(),
+ new TaskResourceUsage(randomLong(), randomLong()),
shards
);
action.run();
@@ -771,7 +839,8 @@ private SearchDfsQueryThenFetchAsyncAction createSearchDfsQueryThenFetchAsyncAct
SearchResponse.Clusters.EMPTY,
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(searchRequestOperationsListeners, logger),
- searchRequest
+ searchRequest,
+ () -> null
),
NoopTracer.INSTANCE
);
@@ -825,7 +894,8 @@ private SearchQueryThenFetchAsyncAction createSearchQueryThenFetchAsyncAction(
SearchResponse.Clusters.EMPTY,
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(searchRequestOperationsListeners, logger),
- searchRequest
+ searchRequest,
+ () -> null
),
NoopTracer.INSTANCE
) {
diff --git a/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java b/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java
index 1881c705fe6b3..bb51aeaeee9dd 100644
--- a/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java
+++ b/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java
@@ -170,7 +170,7 @@ public void run() throws IOException {
}
},
SearchResponse.Clusters.EMPTY,
- new SearchRequestContext(searchRequestOperationsListener, searchRequest),
+ new SearchRequestContext(searchRequestOperationsListener, searchRequest, () -> null),
NoopTracer.INSTANCE
);
@@ -268,7 +268,7 @@ public void run() throws IOException {
}
},
SearchResponse.Clusters.EMPTY,
- new SearchRequestContext(searchRequestOperationsListener, searchRequest),
+ new SearchRequestContext(searchRequestOperationsListener, searchRequest, () -> null),
NoopTracer.INSTANCE
);
@@ -366,7 +366,7 @@ public void sendCanMatch(
new ArraySearchPhaseResults<>(iter.size()),
randomIntBetween(1, 32),
SearchResponse.Clusters.EMPTY,
- new SearchRequestContext(searchRequestOperationsListener, searchRequest),
+ new SearchRequestContext(searchRequestOperationsListener, searchRequest, () -> null),
NoopTracer.INSTANCE
) {
@Override
@@ -396,7 +396,7 @@ protected void executePhaseOnShard(
);
},
SearchResponse.Clusters.EMPTY,
- new SearchRequestContext(searchRequestOperationsListener, searchRequest),
+ new SearchRequestContext(searchRequestOperationsListener, searchRequest, () -> null),
NoopTracer.INSTANCE
);
@@ -488,7 +488,7 @@ public void run() {
}
},
SearchResponse.Clusters.EMPTY,
- new SearchRequestContext(searchRequestOperationsListener, searchRequest),
+ new SearchRequestContext(searchRequestOperationsListener, searchRequest, () -> null),
NoopTracer.INSTANCE
);
@@ -595,7 +595,7 @@ public void run() {
}
},
SearchResponse.Clusters.EMPTY,
- new SearchRequestContext(searchRequestOperationsListener, searchRequest),
+ new SearchRequestContext(searchRequestOperationsListener, searchRequest, () -> null),
NoopTracer.INSTANCE
);
@@ -658,7 +658,8 @@ public void sendCanMatch(
ExecutorService executor = OpenSearchExecutors.newDirectExecutorService();
SearchRequestContext searchRequestContext = new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()),
- searchRequest
+ searchRequest,
+ () -> null
);
SearchPhaseController controller = new SearchPhaseController(
diff --git a/server/src/test/java/org/opensearch/action/search/MockSearchPhaseContext.java b/server/src/test/java/org/opensearch/action/search/MockSearchPhaseContext.java
index cc10da8fc1f12..2f3e462f741b8 100644
--- a/server/src/test/java/org/opensearch/action/search/MockSearchPhaseContext.java
+++ b/server/src/test/java/org/opensearch/action/search/MockSearchPhaseContext.java
@@ -182,6 +182,14 @@ public void addReleasable(Releasable releasable) {
// Noop
}
+ /**
+ * Set the resource usage info for this phase
+ */
+ @Override
+ public void setPhaseResourceUsages() {
+ // Noop
+ }
+
@Override
public void execute(Runnable command) {
command.run();
diff --git a/server/src/test/java/org/opensearch/action/search/SearchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/SearchAsyncActionTests.java
index 35e90ff662b19..8fe2d9af217d5 100644
--- a/server/src/test/java/org/opensearch/action/search/SearchAsyncActionTests.java
+++ b/server/src/test/java/org/opensearch/action/search/SearchAsyncActionTests.java
@@ -162,7 +162,7 @@ public void testSkipSearchShards() throws InterruptedException {
new ArraySearchPhaseResults<>(shardsIter.size()),
request.getMaxConcurrentShardRequests(),
SearchResponse.Clusters.EMPTY,
- new SearchRequestContext(searchRequestOperationsListener, request),
+ new SearchRequestContext(searchRequestOperationsListener, request, () -> null),
NoopTracer.INSTANCE
) {
@@ -287,7 +287,7 @@ public void testLimitConcurrentShardRequests() throws InterruptedException {
new ArraySearchPhaseResults<>(shardsIter.size()),
request.getMaxConcurrentShardRequests(),
SearchResponse.Clusters.EMPTY,
- new SearchRequestContext(searchRequestOperationsListener, request),
+ new SearchRequestContext(searchRequestOperationsListener, request, () -> null),
NoopTracer.INSTANCE
) {
@@ -409,7 +409,8 @@ public void sendFreeContext(Transport.Connection connection, ShardSearchContextI
SearchResponse.Clusters.EMPTY,
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()),
- request
+ request,
+ () -> null
),
NoopTracer.INSTANCE
) {
@@ -537,7 +538,8 @@ public void sendFreeContext(Transport.Connection connection, ShardSearchContextI
SearchResponse.Clusters.EMPTY,
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()),
- request
+ request,
+ () -> null
),
NoopTracer.INSTANCE
) {
@@ -657,7 +659,7 @@ public void testAllowPartialResults() throws InterruptedException {
new ArraySearchPhaseResults<>(shardsIter.size()),
request.getMaxConcurrentShardRequests(),
SearchResponse.Clusters.EMPTY,
- new SearchRequestContext(searchRequestOperationsListener, request),
+ new SearchRequestContext(searchRequestOperationsListener, request, () -> null),
NoopTracer.INSTANCE
) {
@Override
diff --git a/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java
index aefbbe80d5fa1..f6a06a51c7b43 100644
--- a/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java
+++ b/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java
@@ -240,7 +240,8 @@ public void sendExecuteQuery(
SearchResponse.Clusters.EMPTY,
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()),
- searchRequest
+ searchRequest,
+ () -> null
),
NoopTracer.INSTANCE
) {
diff --git a/server/src/test/java/org/opensearch/action/search/SearchRequestOperationsListenerSupport.java b/server/src/test/java/org/opensearch/action/search/SearchRequestOperationsListenerSupport.java
index 0f737e00478cb..fdac91a0e3124 100644
--- a/server/src/test/java/org/opensearch/action/search/SearchRequestOperationsListenerSupport.java
+++ b/server/src/test/java/org/opensearch/action/search/SearchRequestOperationsListenerSupport.java
@@ -25,7 +25,8 @@ default void onPhaseEnd(SearchRequestOperationsListener listener, SearchPhaseCon
context,
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()),
- new SearchRequest()
+ new SearchRequest(),
+ () -> null
)
);
}
diff --git a/server/src/test/java/org/opensearch/action/search/SearchRequestSlowLogTests.java b/server/src/test/java/org/opensearch/action/search/SearchRequestSlowLogTests.java
index 91a2552ac3f04..453fc6cd8a74c 100644
--- a/server/src/test/java/org/opensearch/action/search/SearchRequestSlowLogTests.java
+++ b/server/src/test/java/org/opensearch/action/search/SearchRequestSlowLogTests.java
@@ -178,7 +178,8 @@ public void testConcurrentOnRequestEnd() throws InterruptedException {
for (int i = 0; i < numRequests; i++) {
SearchRequestContext searchRequestContext = new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(searchListenersList, logger),
- searchRequest
+ searchRequest,
+ () -> null
);
searchRequestContext.setAbsoluteStartNanos((i < numRequestsLogged) ? 0 : System.nanoTime());
searchRequestContexts.add(searchRequestContext);
@@ -209,7 +210,8 @@ public void testSearchRequestSlowLogHasJsonFields_EmptySearchRequestContext() th
SearchPhaseContext searchPhaseContext = new MockSearchPhaseContext(1, searchRequest);
SearchRequestContext searchRequestContext = new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()),
- searchRequest
+ searchRequest,
+ () -> null
);
SearchRequestSlowLog.SearchRequestSlowLogMessage p = new SearchRequestSlowLog.SearchRequestSlowLogMessage(
searchPhaseContext,
@@ -233,7 +235,8 @@ public void testSearchRequestSlowLogHasJsonFields_NotEmptySearchRequestContext()
SearchPhaseContext searchPhaseContext = new MockSearchPhaseContext(1, searchRequest);
SearchRequestContext searchRequestContext = new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()),
- searchRequest
+ searchRequest,
+ () -> null
);
searchRequestContext.updatePhaseTookMap(SearchPhaseName.FETCH.getName(), 10L);
searchRequestContext.updatePhaseTookMap(SearchPhaseName.QUERY.getName(), 50L);
@@ -262,7 +265,8 @@ public void testSearchRequestSlowLogHasJsonFields_PartialContext() throws IOExce
SearchPhaseContext searchPhaseContext = new MockSearchPhaseContext(1, searchRequest);
SearchRequestContext searchRequestContext = new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()),
- searchRequest
+ searchRequest,
+ () -> null
);
searchRequestContext.updatePhaseTookMap(SearchPhaseName.FETCH.getName(), 10L);
searchRequestContext.updatePhaseTookMap(SearchPhaseName.QUERY.getName(), 50L);
@@ -291,7 +295,8 @@ public void testSearchRequestSlowLogSearchContextPrinterToLog() throws IOExcepti
SearchPhaseContext searchPhaseContext = new MockSearchPhaseContext(1, searchRequest);
SearchRequestContext searchRequestContext = new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()),
- searchRequest
+ searchRequest,
+ () -> null
);
searchRequestContext.updatePhaseTookMap(SearchPhaseName.FETCH.getName(), 10L);
searchRequestContext.updatePhaseTookMap(SearchPhaseName.QUERY.getName(), 50L);
diff --git a/server/src/test/java/org/opensearch/action/search/SearchRequestStatsTests.java b/server/src/test/java/org/opensearch/action/search/SearchRequestStatsTests.java
index fb9b26e3f3ad1..1af3eb2738a58 100644
--- a/server/src/test/java/org/opensearch/action/search/SearchRequestStatsTests.java
+++ b/server/src/test/java/org/opensearch/action/search/SearchRequestStatsTests.java
@@ -60,7 +60,8 @@ public void testSearchRequestStats() {
ctx,
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()),
- new SearchRequest()
+ new SearchRequest(),
+ () -> null
)
);
assertEquals(0, testRequestStats.getPhaseCurrent(searchPhaseName));
@@ -120,7 +121,8 @@ public void testSearchRequestStatsOnPhaseEndConcurrently() throws InterruptedExc
ctx,
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()),
- new SearchRequest()
+ new SearchRequest(),
+ () -> null
)
);
countDownLatch.countDown();
diff --git a/server/src/test/java/org/opensearch/action/search/SearchResponseMergerTests.java b/server/src/test/java/org/opensearch/action/search/SearchResponseMergerTests.java
index ce4d5ca4f7091..0eefa413c1864 100644
--- a/server/src/test/java/org/opensearch/action/search/SearchResponseMergerTests.java
+++ b/server/src/test/java/org/opensearch/action/search/SearchResponseMergerTests.java
@@ -137,7 +137,8 @@ public void testMergeTookInMillis() throws InterruptedException {
SearchResponse.Clusters.EMPTY,
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()),
- new SearchRequest()
+ new SearchRequest(),
+ () -> null
)
);
assertEquals(TimeUnit.NANOSECONDS.toMillis(currentRelativeTime), searchResponse.getTook().millis());
@@ -195,7 +196,8 @@ public void testMergeShardFailures() throws InterruptedException {
clusters,
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()),
- new SearchRequest()
+ new SearchRequest(),
+ () -> null
)
);
assertSame(clusters, mergedResponse.getClusters());
@@ -252,7 +254,8 @@ public void testMergeShardFailuresNullShardTarget() throws InterruptedException
clusters,
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()),
- new SearchRequest()
+ new SearchRequest(),
+ () -> null
)
);
assertSame(clusters, mergedResponse.getClusters());
@@ -304,7 +307,8 @@ public void testMergeShardFailuresNullShardId() throws InterruptedException {
SearchResponse.Clusters.EMPTY,
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()),
- new SearchRequest()
+ new SearchRequest(),
+ () -> null
)
).getShardFailures();
assertThat(Arrays.asList(shardFailures), containsInAnyOrder(expectedFailures.toArray(ShardSearchFailure.EMPTY_ARRAY)));
@@ -344,7 +348,8 @@ public void testMergeProfileResults() throws InterruptedException {
clusters,
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()),
- new SearchRequest()
+ new SearchRequest(),
+ () -> null
)
);
assertSame(clusters, mergedResponse.getClusters());
@@ -412,7 +417,8 @@ public void testMergeCompletionSuggestions() throws InterruptedException {
clusters,
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()),
- new SearchRequest()
+ new SearchRequest(),
+ () -> null
)
);
assertSame(clusters, mergedResponse.getClusters());
@@ -490,7 +496,8 @@ public void testMergeCompletionSuggestionsTieBreak() throws InterruptedException
clusters,
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()),
- new SearchRequest()
+ new SearchRequest(),
+ () -> null
)
);
assertSame(clusters, mergedResponse.getClusters());
@@ -570,7 +577,8 @@ public void testMergeAggs() throws InterruptedException {
clusters,
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()),
- new SearchRequest()
+ new SearchRequest(),
+ () -> null
)
);
assertSame(clusters, mergedResponse.getClusters());
@@ -733,7 +741,8 @@ public void testMergeSearchHits() throws InterruptedException {
clusters,
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()),
- new SearchRequest()
+ new SearchRequest(),
+ () -> null
)
);
@@ -799,7 +808,8 @@ public void testMergeNoResponsesAdded() {
clusters,
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()),
- new SearchRequest()
+ new SearchRequest(),
+ () -> null
)
);
assertSame(clusters, response.getClusters());
@@ -878,7 +888,8 @@ public void testMergeEmptySearchHitsWithNonEmpty() {
clusters,
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()),
- new SearchRequest()
+ new SearchRequest(),
+ () -> null
)
);
assertEquals(10, mergedResponse.getHits().getTotalHits().value);
@@ -926,7 +937,8 @@ public void testMergeOnlyEmptyHits() {
clusters,
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()),
- new SearchRequest()
+ new SearchRequest(),
+ () -> null
)
);
assertEquals(expectedTotalHits, mergedResponse.getHits().getTotalHits());
diff --git a/server/src/test/java/org/opensearch/action/search/TransportSearchActionTests.java b/server/src/test/java/org/opensearch/action/search/TransportSearchActionTests.java
index da19c839f3826..84955d01a59ce 100644
--- a/server/src/test/java/org/opensearch/action/search/TransportSearchActionTests.java
+++ b/server/src/test/java/org/opensearch/action/search/TransportSearchActionTests.java
@@ -487,7 +487,8 @@ public void testCCSRemoteReduceMergeFails() throws Exception {
(r, l) -> setOnce.set(Tuple.tuple(r, l)),
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()),
- searchRequest
+ searchRequest,
+ () -> null
)
);
if (localIndices == null) {
@@ -549,7 +550,8 @@ public void testCCSRemoteReduce() throws Exception {
(r, l) -> setOnce.set(Tuple.tuple(r, l)),
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()),
- searchRequest
+ searchRequest,
+ () -> null
)
);
if (localIndices == null) {
@@ -590,7 +592,8 @@ public void testCCSRemoteReduce() throws Exception {
(r, l) -> setOnce.set(Tuple.tuple(r, l)),
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()),
- searchRequest
+ searchRequest,
+ () -> null
)
);
if (localIndices == null) {
@@ -652,7 +655,8 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti
(r, l) -> setOnce.set(Tuple.tuple(r, l)),
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()),
- searchRequest
+ searchRequest,
+ () -> null
)
);
if (localIndices == null) {
@@ -696,7 +700,8 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti
(r, l) -> setOnce.set(Tuple.tuple(r, l)),
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()),
- searchRequest
+ searchRequest,
+ () -> null
)
);
if (localIndices == null) {
@@ -751,7 +756,8 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti
(r, l) -> setOnce.set(Tuple.tuple(r, l)),
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()),
- searchRequest
+ searchRequest,
+ () -> null
)
);
if (localIndices == null) {
diff --git a/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java
index 86de008b5dee5..622507f885814 100644
--- a/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java
+++ b/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java
@@ -2291,7 +2291,8 @@ public void onFailure(final Exception e) {
new FetchPhase(Collections.emptyList()),
responseCollectorService,
new NoneCircuitBreakerService(),
- null
+ null,
+ new TaskResourceTrackingService(settings, clusterSettings, threadPool)
);
SearchPhaseController searchPhaseController = new SearchPhaseController(
writableRegistry(),
@@ -2326,7 +2327,8 @@ public void onFailure(final Exception e) {
),
NoopMetricsRegistry.INSTANCE,
searchRequestOperationsCompositeListenerFactory,
- NoopTracer.INSTANCE
+ NoopTracer.INSTANCE,
+ new TaskResourceTrackingService(settings, clusterSettings, threadPool)
)
);
actions.put(
diff --git a/server/src/test/java/org/opensearch/tasks/TaskResourceInfoTests.java b/server/src/test/java/org/opensearch/tasks/TaskResourceInfoTests.java
new file mode 100644
index 0000000000000..e0bfb8710bbaa
--- /dev/null
+++ b/server/src/test/java/org/opensearch/tasks/TaskResourceInfoTests.java
@@ -0,0 +1,106 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.tasks;
+
+import org.opensearch.common.io.stream.BytesStreamOutput;
+import org.opensearch.core.common.bytes.BytesReference;
+import org.opensearch.core.common.io.stream.StreamInput;
+import org.opensearch.core.tasks.resourcetracker.TaskResourceInfo;
+import org.opensearch.core.tasks.resourcetracker.TaskResourceUsage;
+import org.opensearch.core.xcontent.MediaTypeRegistry;
+import org.opensearch.core.xcontent.ToXContent;
+import org.opensearch.core.xcontent.XContentBuilder;
+import org.opensearch.test.OpenSearchTestCase;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Locale;
+
+/**
+ * Test cases for TaskResourceInfo
+ */
+public class TaskResourceInfoTests extends OpenSearchTestCase {
+ private final Long cpuUsage = randomNonNegativeLong();
+ private final Long memoryUsage = randomNonNegativeLong();
+ private final String action = randomAlphaOfLengthBetween(1, 10);
+ private final Long taskId = randomNonNegativeLong();
+ private final Long parentTaskId = randomNonNegativeLong();
+ private final String nodeId = randomAlphaOfLengthBetween(1, 10);
+ private TaskResourceInfo taskResourceInfo;
+ private TaskResourceUsage taskResourceUsage;
+
+ @Before
+ public void setUpVariables() {
+ taskResourceUsage = new TaskResourceUsage(cpuUsage, memoryUsage);
+ taskResourceInfo = new TaskResourceInfo(action, taskId, parentTaskId, nodeId, taskResourceUsage);
+ }
+
+ public void testGetters() {
+ assertEquals(action, taskResourceInfo.getAction());
+ assertEquals(taskId.longValue(), taskResourceInfo.getTaskId());
+ assertEquals(parentTaskId.longValue(), taskResourceInfo.getParentTaskId());
+ assertEquals(nodeId, taskResourceInfo.getNodeId());
+ assertEquals(taskResourceUsage, taskResourceInfo.getTaskResourceUsage());
+ }
+
+ public void testEqualsAndHashCode() {
+ TaskResourceInfo taskResourceInfoCopy = new TaskResourceInfo(action, taskId, parentTaskId, nodeId, taskResourceUsage);
+ assertEquals(taskResourceInfo, taskResourceInfoCopy);
+ assertEquals(taskResourceInfo.hashCode(), taskResourceInfoCopy.hashCode());
+ TaskResourceInfo differentTaskResourceInfo = new TaskResourceInfo(
+ "differentAction",
+ taskId,
+ parentTaskId,
+ nodeId,
+ taskResourceUsage
+ );
+ assertNotEquals(taskResourceInfo, differentTaskResourceInfo);
+ assertNotEquals(taskResourceInfo.hashCode(), differentTaskResourceInfo.hashCode());
+ }
+
+ public void testSerialization() throws IOException {
+ BytesStreamOutput output = new BytesStreamOutput();
+ taskResourceInfo.writeTo(output);
+ StreamInput input = StreamInput.wrap(output.bytes().toBytesRef().bytes);
+ TaskResourceInfo deserializedTaskResourceInfo = TaskResourceInfo.readFromStream(input);
+ assertEquals(taskResourceInfo, deserializedTaskResourceInfo);
+ }
+
+ public void testToString() {
+ String expectedString = String.format(
+ Locale.ROOT,
+ "{\"action\":\"%s\",\"taskId\":%s,\"parentTaskId\":%s,\"nodeId\":\"%s\",\"taskResourceUsage\":{\"cpu_time_in_nanos\":%s,\"memory_in_bytes\":%s}}",
+ action,
+ taskId,
+ parentTaskId,
+ nodeId,
+ taskResourceUsage.getCpuTimeInNanos(),
+ taskResourceUsage.getMemoryInBytes()
+ );
+ assertTrue(expectedString.equals(taskResourceInfo.toString()));
+ }
+
+ public void testToXContent() throws IOException {
+ char[] expectedXcontent = String.format(
+ Locale.ROOT,
+ "{\"action\":\"%s\",\"taskId\":%s,\"parentTaskId\":%s,\"nodeId\":\"%s\",\"taskResourceUsage\":{\"cpu_time_in_nanos\":%s,\"memory_in_bytes\":%s}}",
+ action,
+ taskId,
+ parentTaskId,
+ nodeId,
+ taskResourceUsage.getCpuTimeInNanos(),
+ taskResourceUsage.getMemoryInBytes()
+ ).toCharArray();
+
+ XContentBuilder builder = MediaTypeRegistry.contentBuilder(MediaTypeRegistry.JSON);
+ char[] xContent = BytesReference.bytes(taskResourceInfo.toXContent(builder, ToXContent.EMPTY_PARAMS)).utf8ToString().toCharArray();
+ assertEquals(Arrays.hashCode(expectedXcontent), Arrays.hashCode(xContent));
+ }
+}
diff --git a/server/src/test/java/org/opensearch/tasks/TaskResourceTrackingServiceTests.java b/server/src/test/java/org/opensearch/tasks/TaskResourceTrackingServiceTests.java
index 45d438f8d04c9..0c19c331e1510 100644
--- a/server/src/test/java/org/opensearch/tasks/TaskResourceTrackingServiceTests.java
+++ b/server/src/test/java/org/opensearch/tasks/TaskResourceTrackingServiceTests.java
@@ -9,11 +9,15 @@
package org.opensearch.tasks;
import org.opensearch.action.admin.cluster.node.tasks.TransportTasksActionTests;
+import org.opensearch.action.search.SearchShardTask;
import org.opensearch.action.search.SearchTask;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.tasks.TaskId;
+import org.opensearch.core.tasks.resourcetracker.ResourceStatsType;
+import org.opensearch.core.tasks.resourcetracker.ResourceUsageMetric;
+import org.opensearch.core.tasks.resourcetracker.TaskResourceInfo;
import org.opensearch.core.tasks.resourcetracker.ThreadResourceInfo;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.TestThreadPool;
@@ -31,6 +35,7 @@
import static org.opensearch.core.tasks.resourcetracker.ResourceStats.CPU;
import static org.opensearch.core.tasks.resourcetracker.ResourceStats.MEMORY;
import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID;
+import static org.opensearch.tasks.TaskResourceTrackingService.TASK_RESOURCE_USAGE;
public class TaskResourceTrackingServiceTests extends OpenSearchTestCase {
@@ -142,6 +147,36 @@ public void testStartingTrackingHandlesMultipleThreadsPerTask() throws Interrupt
assertEquals(numTasks, numExecutions);
}
+ public void testWriteTaskResourceUsage() {
+ SearchShardTask task = new SearchShardTask(1, "test", "test", "task", TaskId.EMPTY_TASK_ID, new HashMap<>());
+ taskResourceTrackingService.setTaskResourceTrackingEnabled(true);
+ taskResourceTrackingService.startTracking(task);
+ task.startThreadResourceTracking(
+ Thread.currentThread().getId(),
+ ResourceStatsType.WORKER_STATS,
+ new ResourceUsageMetric(CPU, 100),
+ new ResourceUsageMetric(MEMORY, 100)
+ );
+ taskResourceTrackingService.writeTaskResourceUsage(task, "node_1");
+ Map> headers = threadPool.getThreadContext().getResponseHeaders();
+ assertEquals(1, headers.size());
+ assertTrue(headers.containsKey(TASK_RESOURCE_USAGE));
+ }
+
+ public void testGetTaskResourceUsageFromThreadContext() {
+ String taskResourceUsageJson =
+ "{\"action\":\"testAction\",\"taskId\":1,\"parentTaskId\":2,\"nodeId\":\"nodeId\",\"taskResourceUsage\":{\"cpu_time_in_nanos\":1000,\"memory_in_bytes\":2000}}";
+ threadPool.getThreadContext().addResponseHeader(TASK_RESOURCE_USAGE, taskResourceUsageJson);
+ TaskResourceInfo result = taskResourceTrackingService.getTaskResourceUsageFromThreadContext();
+ assertNotNull(result);
+ assertEquals("testAction", result.getAction());
+ assertEquals(1L, result.getTaskId());
+ assertEquals(2L, result.getParentTaskId());
+ assertEquals("nodeId", result.getNodeId());
+ assertEquals(1000L, result.getTaskResourceUsage().getCpuTimeInNanos());
+ assertEquals(2000L, result.getTaskResourceUsage().getMemoryInBytes());
+ }
+
private void verifyThreadContextFixedHeaders(String key, String value) {
assertEquals(threadPool.getThreadContext().getHeader(key), value);
assertEquals(threadPool.getThreadContext().getTransient(key), value);
diff --git a/test/framework/src/main/java/org/opensearch/node/MockNode.java b/test/framework/src/main/java/org/opensearch/node/MockNode.java
index e6c7e21d5b3ea..19c65ec169d3c 100644
--- a/test/framework/src/main/java/org/opensearch/node/MockNode.java
+++ b/test/framework/src/main/java/org/opensearch/node/MockNode.java
@@ -60,6 +60,7 @@
import org.opensearch.search.SearchService;
import org.opensearch.search.fetch.FetchPhase;
import org.opensearch.search.query.QueryPhase;
+import org.opensearch.tasks.TaskResourceTrackingService;
import org.opensearch.telemetry.tracing.Tracer;
import org.opensearch.test.MockHttpTransport;
import org.opensearch.test.transport.MockTransportService;
@@ -155,7 +156,8 @@ protected SearchService newSearchService(
FetchPhase fetchPhase,
ResponseCollectorService responseCollectorService,
CircuitBreakerService circuitBreakerService,
- Executor indexSearcherExecutor
+ Executor indexSearcherExecutor,
+ TaskResourceTrackingService taskResourceTrackingService
) {
if (getPluginsService().filterPlugins(MockSearchService.TestPlugin.class).isEmpty()) {
return super.newSearchService(
@@ -168,7 +170,8 @@ protected SearchService newSearchService(
fetchPhase,
responseCollectorService,
circuitBreakerService,
- indexSearcherExecutor
+ indexSearcherExecutor,
+ taskResourceTrackingService
);
}
return new MockSearchService(
@@ -180,7 +183,8 @@ protected SearchService newSearchService(
queryPhase,
fetchPhase,
circuitBreakerService,
- indexSearcherExecutor
+ indexSearcherExecutor,
+ taskResourceTrackingService
);
}
diff --git a/test/framework/src/main/java/org/opensearch/search/MockSearchService.java b/test/framework/src/main/java/org/opensearch/search/MockSearchService.java
index a0bbcb7be05f9..6c9ace06c8219 100644
--- a/test/framework/src/main/java/org/opensearch/search/MockSearchService.java
+++ b/test/framework/src/main/java/org/opensearch/search/MockSearchService.java
@@ -42,6 +42,7 @@
import org.opensearch.search.fetch.FetchPhase;
import org.opensearch.search.internal.ReaderContext;
import org.opensearch.search.query.QueryPhase;
+import org.opensearch.tasks.TaskResourceTrackingService;
import org.opensearch.threadpool.ThreadPool;
import java.util.HashMap;
@@ -96,7 +97,8 @@ public MockSearchService(
QueryPhase queryPhase,
FetchPhase fetchPhase,
CircuitBreakerService circuitBreakerService,
- Executor indexSearcherExecutor
+ Executor indexSearcherExecutor,
+ TaskResourceTrackingService taskResourceTrackingService
) {
super(
clusterService,
@@ -108,7 +110,8 @@ public MockSearchService(
fetchPhase,
null,
circuitBreakerService,
- indexSearcherExecutor
+ indexSearcherExecutor,
+ taskResourceTrackingService
);
}