From 69629ff2b5801f3ea16cdd3585b9a322febcf032 Mon Sep 17 00:00:00 2001 From: Chenyang Ji Date: Tue, 4 Jun 2024 15:23:51 -0700 Subject: [PATCH] move resource usages interactions into TaskResourceTrackingService Signed-off-by: Chenyang Ji --- CHANGELOG.md | 2 +- .../search/AbstractSearchAsyncAction.java | 3 +- .../action/search/SearchRequestContext.java | 44 +++------ .../action/search/TransportSearchAction.java | 19 ++-- .../org/opensearch/search/SearchService.java | 70 ++----------- .../tasks/TaskResourceTrackingService.java | 97 ++++++++++++++++++- .../AbstractSearchAsyncActionTests.java | 2 +- .../snapshots/SnapshotResiliencyTests.java | 5 +- .../TaskResourceTrackingServiceTests.java | 35 +++++++ 9 files changed, 167 insertions(+), 110 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 549ead2b587cf..4b714d882ea65 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,7 +20,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - [Remote Store] Add support to disable flush based on translog reader count ([#14027](https://github.com/opensearch-project/OpenSearch/pull/14027)) - [Query Insights] Add exporter support for top n queries ([#12982](https://github.com/opensearch-project/OpenSearch/pull/12982)) - [Query Insights] Add X-Opaque-Id to search request metadata for top n queries ([#13374](https://github.com/opensearch-project/OpenSearch/pull/13374)) -- Add support for query-level resource usage tracking ([#13172](https://github.com/opensearch-project/OpenSearch/pull/13172)) +- Add support for query level resource usage tracking ([#13172](https://github.com/opensearch-project/OpenSearch/pull/13172)) ### Dependencies - Bump `com.github.spullara.mustache.java:compiler` from 0.9.10 to 0.9.13 ([#13329](https://github.com/opensearch-project/OpenSearch/pull/13329), [#13559](https://github.com/opensearch-project/OpenSearch/pull/13559)) diff --git a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java index af84422df7067..f0fc05c595d6f 100644 --- a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java @@ -51,6 +51,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.action.ShardOperationFailedException; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.tasks.resourcetracker.TaskResourceInfo; import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.SearchShardTarget; import org.opensearch.search.internal.AliasFilter; @@ -628,7 +629,7 @@ protected void onShardResult(Result result, SearchShardIterator shardIt) { } public void setPhaseResourceUsages() { - String taskResourceUsage = searchRequestContext.getTaskResourceUsageSupplier().get(); + TaskResourceInfo taskResourceUsage = searchRequestContext.getTaskResourceUsageSupplier().get(); searchRequestContext.recordPhaseResourceUsage(taskResourceUsage); } diff --git a/server/src/main/java/org/opensearch/action/search/SearchRequestContext.java b/server/src/main/java/org/opensearch/action/search/SearchRequestContext.java index 45bb10b989ca7..111d9c64550b3 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchRequestContext.java +++ b/server/src/main/java/org/opensearch/action/search/SearchRequestContext.java @@ -12,21 +12,15 @@ import org.apache.logging.log4j.Logger; import org.apache.lucene.search.TotalHits; import org.opensearch.common.annotation.InternalApi; -import org.opensearch.common.xcontent.XContentHelper; -import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.tasks.resourcetracker.TaskResourceInfo; -import org.opensearch.core.xcontent.DeprecationHandler; -import org.opensearch.core.xcontent.MediaTypeRegistry; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.XContentParser; -import java.io.IOException; import java.util.ArrayList; import java.util.EnumMap; import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.concurrent.LinkedBlockingQueue; import java.util.function.Supplier; /** @@ -44,20 +38,20 @@ public class SearchRequestContext { private final EnumMap shardStats; private final SearchRequest searchRequest; - private final List phaseResourceUsage; - private final Supplier taskResourceUsageSupplier; + private final LinkedBlockingQueue phaseResourceUsage; + private final Supplier taskResourceUsageSupplier; SearchRequestContext( final SearchRequestOperationsListener searchRequestOperationsListener, final SearchRequest searchRequest, - final Supplier taskResourceUsageSupplier + final Supplier taskResourceUsageSupplier ) { this.searchRequestOperationsListener = searchRequestOperationsListener; this.absoluteStartNanos = System.nanoTime(); this.phaseTookMap = new HashMap<>(); this.shardStats = new EnumMap<>(ShardStatsFieldNames.class); this.searchRequest = searchRequest; - this.phaseResourceUsage = new ArrayList<>(); + this.phaseResourceUsage = new LinkedBlockingQueue<>(); this.taskResourceUsageSupplier = taskResourceUsageSupplier; } @@ -130,32 +124,22 @@ String formattedShardStats() { } } - public Supplier getTaskResourceUsageSupplier() { + public Supplier getTaskResourceUsageSupplier() { return taskResourceUsageSupplier; } - public SearchRequest getRequest() { - return searchRequest; - } - - public void recordPhaseResourceUsage(String usage) { - try { - if (usage != null && !usage.isEmpty()) { - XContentParser parser = XContentHelper.createParser( - NamedXContentRegistry.EMPTY, - DeprecationHandler.THROW_UNSUPPORTED_OPERATION, - new BytesArray(usage), - MediaTypeRegistry.JSON - ); - this.phaseResourceUsage.add(TaskResourceInfo.PARSER.apply(parser, null)); - } - } catch (IOException e) { - logger.debug("fail to parse phase resource usages: ", e); + public void recordPhaseResourceUsage(TaskResourceInfo usage) { + if (usage != null) { + this.phaseResourceUsage.add(usage); } } public List getPhaseResourceUsage() { - return phaseResourceUsage; + return new ArrayList<>(phaseResourceUsage); + } + + public SearchRequest getRequest() { + return searchRequest; } } diff --git a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java index 09da8d03a0aeb..6e380775355a2 100644 --- a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java @@ -87,6 +87,7 @@ import org.opensearch.search.profile.SearchProfileShardResults; import org.opensearch.tasks.CancellableTask; import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.telemetry.metrics.MetricsRegistry; import org.opensearch.telemetry.tracing.Span; import org.opensearch.telemetry.tracing.SpanBuilder; @@ -125,7 +126,6 @@ import static org.opensearch.action.search.SearchType.DFS_QUERY_THEN_FETCH; import static org.opensearch.action.search.SearchType.QUERY_THEN_FETCH; import static org.opensearch.search.sort.FieldSortBuilder.hasPrimaryFieldSort; -import static org.opensearch.tasks.TaskResourceTrackingService.TASK_RESOURCE_USAGE; /** * Perform search action @@ -187,6 +187,7 @@ public class TransportSearchAction extends HandledTransportAction) SearchRequest::new); this.client = client; @@ -225,6 +227,7 @@ public TransportSearchAction( clusterService.getClusterSettings() .addSettingsUpdateConsumer(SEARCH_QUERY_METRICS_ENABLED_SETTING, this::setSearchQueryMetricsEnabled); this.tracer = tracer; + this.taskResourceTrackingService = taskResourceTrackingService; } private void setSearchQueryMetricsEnabled(boolean searchQueryMetricsEnabled) { @@ -452,14 +455,10 @@ private void executeRequest( logger, TraceableSearchRequestOperationsListener.create(tracer, requestSpan) ); - SearchRequestContext searchRequestContext = new SearchRequestContext(requestOperationsListeners, originalSearchRequest, () -> { - List taskResourceUsages = threadPool.getThreadContext().getResponseHeaders().get(TASK_RESOURCE_USAGE); - if (taskResourceUsages != null && taskResourceUsages.size() > 0) { - return taskResourceUsages.get(0); - } - return null; - } - + SearchRequestContext searchRequestContext = new SearchRequestContext( + requestOperationsListeners, + originalSearchRequest, + taskResourceTrackingService::getTaskResourceUsageFromThreadContext ); searchRequestContext.getSearchRequestOperationsListener().onRequestStart(searchRequestContext); diff --git a/server/src/main/java/org/opensearch/search/SearchService.java b/server/src/main/java/org/opensearch/search/SearchService.java index d5c2b13eb5041..45f111d889522 100644 --- a/server/src/main/java/org/opensearch/search/SearchService.java +++ b/server/src/main/java/org/opensearch/search/SearchService.java @@ -73,13 +73,6 @@ import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.indices.breaker.CircuitBreakerService; -import org.opensearch.core.tasks.resourcetracker.ResourceStats; -import org.opensearch.core.tasks.resourcetracker.ResourceStatsType; -import org.opensearch.core.tasks.resourcetracker.ResourceUsageInfo; -import org.opensearch.core.tasks.resourcetracker.ResourceUsageMetric; -import org.opensearch.core.tasks.resourcetracker.TaskResourceInfo; -import org.opensearch.core.tasks.resourcetracker.TaskResourceUsage; -import org.opensearch.core.tasks.resourcetracker.ThreadResourceInfo; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.IndexService; import org.opensearch.index.IndexSettings; @@ -168,7 +161,6 @@ import static org.opensearch.common.unit.TimeValue.timeValueHours; import static org.opensearch.common.unit.TimeValue.timeValueMillis; import static org.opensearch.common.unit.TimeValue.timeValueMinutes; -import static org.opensearch.tasks.TaskResourceTrackingService.TASK_RESOURCE_USAGE; /** * The main search service @@ -571,7 +563,7 @@ private DfsSearchResult executeDfsPhase(ShardSearchRequest request, SearchShardT processFailure(readerContext, e); throw e; } finally { - writeTaskResourceUsage(task); + taskResourceTrackingService.writeTaskResourceUsage(task, clusterService.localNode().getId()); } } @@ -675,7 +667,7 @@ private SearchPhaseResult executeQueryPhase(ShardSearchRequest request, SearchSh processFailure(readerContext, e); throw e; } finally { - writeTaskResourceUsage(task); + taskResourceTrackingService.writeTaskResourceUsage(task, clusterService.localNode().getId()); } } @@ -722,7 +714,7 @@ public void executeQueryPhase( // we handle the failure in the failure listener below throw e; } finally { - writeTaskResourceUsage(task); + taskResourceTrackingService.writeTaskResourceUsage(task, clusterService.localNode().getId()); } }, wrapFailureListener(listener, readerContext, markAsUsed)); } @@ -756,7 +748,7 @@ public void executeQueryPhase(QuerySearchRequest request, SearchShardTask task, // we handle the failure in the failure listener below throw e; } finally { - writeTaskResourceUsage(task); + taskResourceTrackingService.writeTaskResourceUsage(task, clusterService.localNode().getId()); } }, wrapFailureListener(listener, readerContext, markAsUsed)); } @@ -807,7 +799,7 @@ public void executeFetchPhase( // we handle the failure in the failure listener below throw e; } finally { - writeTaskResourceUsage(task); + taskResourceTrackingService.writeTaskResourceUsage(task, clusterService.localNode().getId()); } }, wrapFailureListener(listener, readerContext, markAsUsed)); } @@ -839,7 +831,7 @@ public void executeFetchPhase(ShardFetchRequest request, SearchShardTask task, A // we handle the failure in the failure listener below throw e; } finally { - writeTaskResourceUsage(task); + taskResourceTrackingService.writeTaskResourceUsage(task, clusterService.localNode().getId()); } }, wrapFailureListener(listener, readerContext, markAsUsed)); } @@ -1139,56 +1131,6 @@ private DefaultSearchContext createSearchContext(ReaderContext reader, ShardSear return searchContext; } - private void writeTaskResourceUsage(SearchShardTask task) { - try { - // Get resource usages from when the task started - ThreadResourceInfo threadResourceInfo = task.getActiveThreadResourceInfo( - Thread.currentThread().getId(), - ResourceStatsType.WORKER_STATS - ); - if (threadResourceInfo == null) { - return; - } - Map startValues = threadResourceInfo.getResourceUsageInfo().getStatsInfo(); - if (!(startValues.containsKey(ResourceStats.CPU) && startValues.containsKey(ResourceStats.MEMORY))) { - return; - } - // Get current resource usages - ResourceUsageMetric[] endValues = taskResourceTrackingService.getResourceUsageMetricsForThread(Thread.currentThread().getId()); - long cpu = -1, mem = -1; - for (ResourceUsageMetric endValue : endValues) { - if (endValue.getStats() == ResourceStats.MEMORY) { - mem = endValue.getValue(); - } else if (endValue.getStats() == ResourceStats.CPU) { - cpu = endValue.getValue(); - } - } - if (cpu == -1 || mem == -1) { - logger.debug("Invalid resource usage value, cpu [{}], memory [{}]: ", cpu, mem); - return; - } - - // Build task resource usage info - TaskResourceInfo taskResourceInfo = new TaskResourceInfo.Builder().setAction(task.getAction()) - .setTaskId(task.getId()) - .setParentTaskId(task.getParentTaskId().getId()) - .setNodeId(clusterService.localNode().getId()) - .setTaskResourceUsage( - new TaskResourceUsage( - cpu - startValues.get(ResourceStats.CPU).getStartValue(), - mem - startValues.get(ResourceStats.MEMORY).getStartValue() - ) - ) - .build(); - - // Remove the existing TASK_RESOURCE_USAGE header since it would have come from an earlier phase in the same request. - threadPool.getThreadContext().removeResponseHeader(TASK_RESOURCE_USAGE); - threadPool.getThreadContext().addResponseHeader(TASK_RESOURCE_USAGE, taskResourceInfo.toString()); - } catch (Exception e) { - logger.debug("Error during writing task resource usage: ", e); - } - } - private void freeAllContextForIndex(Index index) { assert index != null; for (ReaderContext ctx : activeReaders.values()) { diff --git a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java index 59e719a3c3250..564eff6c10df6 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java +++ b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java @@ -14,6 +14,7 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import org.opensearch.ExceptionsHelper; +import org.opensearch.action.search.SearchShardTask; import org.opensearch.common.SuppressForbidden; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.ClusterSettings; @@ -22,12 +23,23 @@ import org.opensearch.common.util.concurrent.ConcurrentCollections; import org.opensearch.common.util.concurrent.ConcurrentMapLong; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.tasks.resourcetracker.ResourceStats; +import org.opensearch.core.tasks.resourcetracker.ResourceStatsType; +import org.opensearch.core.tasks.resourcetracker.ResourceUsageInfo; import org.opensearch.core.tasks.resourcetracker.ResourceUsageMetric; +import org.opensearch.core.tasks.resourcetracker.TaskResourceInfo; +import org.opensearch.core.tasks.resourcetracker.TaskResourceUsage; import org.opensearch.core.tasks.resourcetracker.ThreadResourceInfo; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; import org.opensearch.threadpool.RunnableTaskExecutionListener; import org.opensearch.threadpool.ThreadPool; +import java.io.IOException; import java.lang.management.ManagementFactory; import java.util.ArrayList; import java.util.Collections; @@ -212,7 +224,7 @@ public Map getResourceAwareTasks() { return Collections.unmodifiableMap(resourceAwareTasks); } - public ResourceUsageMetric[] getResourceUsageMetricsForThread(long threadId) { + private ResourceUsageMetric[] getResourceUsageMetricsForThread(long threadId) { ResourceUsageMetric currentMemoryUsage = new ResourceUsageMetric( ResourceStats.MEMORY, threadMXBean.getThreadAllocatedBytes(threadId) @@ -262,6 +274,89 @@ private ThreadContext.StoredContext addTaskIdToThreadContext(Task task) { return storedContext; } + /** + * Get the current task level resource usage. + * + * @param task {@link SearchShardTask} + * @param nodeId the local nodeId + */ + public void writeTaskResourceUsage(SearchShardTask task, String nodeId) { + try { + // Get resource usages from when the task started + ThreadResourceInfo threadResourceInfo = task.getActiveThreadResourceInfo( + Thread.currentThread().getId(), + ResourceStatsType.WORKER_STATS + ); + if (threadResourceInfo == null) { + return; + } + Map startValues = threadResourceInfo.getResourceUsageInfo().getStatsInfo(); + if (!(startValues.containsKey(ResourceStats.CPU) && startValues.containsKey(ResourceStats.MEMORY))) { + return; + } + // Get current resource usages + ResourceUsageMetric[] endValues = getResourceUsageMetricsForThread(Thread.currentThread().getId()); + long cpu = -1, mem = -1; + for (ResourceUsageMetric endValue : endValues) { + if (endValue.getStats() == ResourceStats.MEMORY) { + mem = endValue.getValue(); + } else if (endValue.getStats() == ResourceStats.CPU) { + cpu = endValue.getValue(); + } + } + if (cpu == -1 || mem == -1) { + logger.debug("Invalid resource usage value, cpu [{}], memory [{}]: ", cpu, mem); + return; + } + + // Build task resource usage info + TaskResourceInfo taskResourceInfo = new TaskResourceInfo.Builder().setAction(task.getAction()) + .setTaskId(task.getId()) + .setParentTaskId(task.getParentTaskId().getId()) + .setNodeId(nodeId) + .setTaskResourceUsage( + new TaskResourceUsage( + cpu - startValues.get(ResourceStats.CPU).getStartValue(), + mem - startValues.get(ResourceStats.MEMORY).getStartValue() + ) + ) + .build(); + // Remove the existing TASK_RESOURCE_USAGE header since it would have come from an earlier phase in the same request. + synchronized (this) { + threadPool.getThreadContext().removeResponseHeader(TASK_RESOURCE_USAGE); + threadPool.getThreadContext().addResponseHeader(TASK_RESOURCE_USAGE, taskResourceInfo.toString()); + } + } catch (Exception e) { + logger.debug("Error during writing task resource usage: ", e); + } + } + + /** + * Get the task resource usages from {@link ThreadContext} + * + * @return {@link TaskResourceInfo} + */ + public TaskResourceInfo getTaskResourceUsageFromThreadContext() { + List taskResourceUsages = threadPool.getThreadContext().getResponseHeaders().get(TASK_RESOURCE_USAGE); + if (taskResourceUsages != null && taskResourceUsages.size() > 0) { + String usage = taskResourceUsages.get(0); + try { + if (usage != null && !usage.isEmpty()) { + XContentParser parser = XContentHelper.createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + new BytesArray(usage), + MediaTypeRegistry.JSON + ); + return TaskResourceInfo.PARSER.apply(parser, null); + } + } catch (IOException e) { + logger.debug("fail to parse phase resource usages: ", e); + } + } + return null; + } + /** * Listener that gets invoked when a task execution completes. */ diff --git a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java index 730f0569f8bc5..27336e86e52b0 100644 --- a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java @@ -201,7 +201,7 @@ private AbstractSearchAsyncAction createAction( new SearchRequestContext( new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()), request, - () -> "" + () -> null ), NoopTracer.INSTANCE ) { diff --git a/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java index bbd1bcdc35c82..622507f885814 100644 --- a/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java +++ b/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java @@ -2292,7 +2292,7 @@ public void onFailure(final Exception e) { responseCollectorService, new NoneCircuitBreakerService(), null, - null + new TaskResourceTrackingService(settings, clusterSettings, threadPool) ); SearchPhaseController searchPhaseController = new SearchPhaseController( writableRegistry(), @@ -2327,7 +2327,8 @@ public void onFailure(final Exception e) { ), NoopMetricsRegistry.INSTANCE, searchRequestOperationsCompositeListenerFactory, - NoopTracer.INSTANCE + NoopTracer.INSTANCE, + new TaskResourceTrackingService(settings, clusterSettings, threadPool) ) ); actions.put( diff --git a/server/src/test/java/org/opensearch/tasks/TaskResourceTrackingServiceTests.java b/server/src/test/java/org/opensearch/tasks/TaskResourceTrackingServiceTests.java index 45d438f8d04c9..0c19c331e1510 100644 --- a/server/src/test/java/org/opensearch/tasks/TaskResourceTrackingServiceTests.java +++ b/server/src/test/java/org/opensearch/tasks/TaskResourceTrackingServiceTests.java @@ -9,11 +9,15 @@ package org.opensearch.tasks; import org.opensearch.action.admin.cluster.node.tasks.TransportTasksActionTests; +import org.opensearch.action.search.SearchShardTask; import org.opensearch.action.search.SearchTask; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.tasks.TaskId; +import org.opensearch.core.tasks.resourcetracker.ResourceStatsType; +import org.opensearch.core.tasks.resourcetracker.ResourceUsageMetric; +import org.opensearch.core.tasks.resourcetracker.TaskResourceInfo; import org.opensearch.core.tasks.resourcetracker.ThreadResourceInfo; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; @@ -31,6 +35,7 @@ import static org.opensearch.core.tasks.resourcetracker.ResourceStats.CPU; import static org.opensearch.core.tasks.resourcetracker.ResourceStats.MEMORY; import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID; +import static org.opensearch.tasks.TaskResourceTrackingService.TASK_RESOURCE_USAGE; public class TaskResourceTrackingServiceTests extends OpenSearchTestCase { @@ -142,6 +147,36 @@ public void testStartingTrackingHandlesMultipleThreadsPerTask() throws Interrupt assertEquals(numTasks, numExecutions); } + public void testWriteTaskResourceUsage() { + SearchShardTask task = new SearchShardTask(1, "test", "test", "task", TaskId.EMPTY_TASK_ID, new HashMap<>()); + taskResourceTrackingService.setTaskResourceTrackingEnabled(true); + taskResourceTrackingService.startTracking(task); + task.startThreadResourceTracking( + Thread.currentThread().getId(), + ResourceStatsType.WORKER_STATS, + new ResourceUsageMetric(CPU, 100), + new ResourceUsageMetric(MEMORY, 100) + ); + taskResourceTrackingService.writeTaskResourceUsage(task, "node_1"); + Map> headers = threadPool.getThreadContext().getResponseHeaders(); + assertEquals(1, headers.size()); + assertTrue(headers.containsKey(TASK_RESOURCE_USAGE)); + } + + public void testGetTaskResourceUsageFromThreadContext() { + String taskResourceUsageJson = + "{\"action\":\"testAction\",\"taskId\":1,\"parentTaskId\":2,\"nodeId\":\"nodeId\",\"taskResourceUsage\":{\"cpu_time_in_nanos\":1000,\"memory_in_bytes\":2000}}"; + threadPool.getThreadContext().addResponseHeader(TASK_RESOURCE_USAGE, taskResourceUsageJson); + TaskResourceInfo result = taskResourceTrackingService.getTaskResourceUsageFromThreadContext(); + assertNotNull(result); + assertEquals("testAction", result.getAction()); + assertEquals(1L, result.getTaskId()); + assertEquals(2L, result.getParentTaskId()); + assertEquals("nodeId", result.getNodeId()); + assertEquals(1000L, result.getTaskResourceUsage().getCpuTimeInNanos()); + assertEquals(2000L, result.getTaskResourceUsage().getMemoryInBytes()); + } + private void verifyThreadContextFixedHeaders(String key, String value) { assertEquals(threadPool.getThreadContext().getHeader(key), value); assertEquals(threadPool.getThreadContext().getTransient(key), value);