diff --git a/src/main/java/com/o19s/es/ltr/LtrQueryParserPlugin.java b/src/main/java/com/o19s/es/ltr/LtrQueryParserPlugin.java index 8dbbb9b..f7b2617 100644 --- a/src/main/java/com/o19s/es/ltr/LtrQueryParserPlugin.java +++ b/src/main/java/com/o19s/es/ltr/LtrQueryParserPlugin.java @@ -19,6 +19,7 @@ import ciir.umass.edu.learning.RankerFactory; import org.opensearch.ltr.breaker.LTRCircuitBreakerService; import org.opensearch.ltr.settings.LTRSettings; +import org.opensearch.ltr.rest.RestStatsLTRAction; import org.opensearch.ltr.stats.LTRStat; import org.opensearch.ltr.stats.LTRStats; import org.opensearch.ltr.stats.StatName; @@ -26,6 +27,8 @@ import org.opensearch.ltr.stats.suppliers.PluginHealthStatusSupplier; import org.opensearch.ltr.stats.suppliers.StoreStatsSupplier; import org.opensearch.ltr.stats.suppliers.CounterSupplier; +import org.opensearch.ltr.transport.LTRStatsAction; +import org.opensearch.ltr.transport.TransportLTRStatsAction; import com.o19s.es.explore.ExplorerQueryBuilder; import com.o19s.es.ltr.action.AddFeaturesToSetAction; import com.o19s.es.ltr.action.CachesStatsAction; @@ -201,6 +204,7 @@ public List getRestHandlers(Settings settings, RestController restC list.add(new RestFeatureStoreCaches()); list.add(new RestCreateModelFromSet()); list.add(new RestAddFeatureToSet()); + list.add(new RestStatsLTRAction(ltrStats)); return unmodifiableList(list); } @@ -212,7 +216,8 @@ public List getRestHandlers(Settings settings, RestController restC new ActionHandler<>(ClearCachesAction.INSTANCE, TransportClearCachesAction.class), new ActionHandler<>(AddFeaturesToSetAction.INSTANCE, TransportAddFeatureToSetAction.class), new ActionHandler<>(CreateModelFromSetAction.INSTANCE, TransportCreateModelFromSetAction.class), - new ActionHandler<>(ListStoresAction.INSTANCE, TransportListStoresAction.class))); + new ActionHandler<>(ListStoresAction.INSTANCE, TransportListStoresAction.class), + new ActionHandler<>(LTRStatsAction.INSTANCE, TransportLTRStatsAction.class))); } @Override diff --git a/src/main/java/org/opensearch/ltr/rest/RestStatsLTRAction.java b/src/main/java/org/opensearch/ltr/rest/RestStatsLTRAction.java new file mode 100644 index 0000000..8bbd534 --- /dev/null +++ b/src/main/java/org/opensearch/ltr/rest/RestStatsLTRAction.java @@ -0,0 +1,164 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.opensearch.ltr.rest; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.ltr.stats.LTRStats; +import org.opensearch.ltr.transport.LTRStatsAction; +import org.opensearch.ltr.transport.LTRStatsRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestActions; + +import java.io.IOException; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Set; +import java.util.stream.Collectors; + +import static com.o19s.es.ltr.LtrQueryParserPlugin.LTR_BASE_URI; +import static com.o19s.es.ltr.LtrQueryParserPlugin.LTR_LEGACY_BASE_URI; + +/** + * Provide an API to get information on the plugin usage and + * performance, such as + * + */ +public class RestStatsLTRAction extends BaseRestHandler { + private static final String NAME = "learning_to_rank_stats"; + private final LTRStats ltrStats; + + public RestStatsLTRAction(final LTRStats ltrStats) { + this.ltrStats = ltrStats; + } + + @Override + public String getName() { + return NAME; + } + + @Override + public List routes() { + return List.of(); + } + + @Override + public List replacedRoutes() { + return List.of( + new ReplacedRoute( + RestRequest.Method.GET, + String.format(Locale.ROOT, "%s%s", LTR_BASE_URI, "/{nodeId}/stats/"), + RestRequest.Method.GET, + String.format(Locale.ROOT, "%s%s", LTR_LEGACY_BASE_URI, "/{nodeId}/stats/") + ), + new ReplacedRoute( + RestRequest.Method.GET, + String.format(Locale.ROOT, "%s%s", LTR_BASE_URI, "/{nodeId}/stats/{stat}"), + RestRequest.Method.GET, + String.format(Locale.ROOT, "%s%s", LTR_LEGACY_BASE_URI, "/{nodeId}/stats/{stat}") + ), + new ReplacedRoute( + RestRequest.Method.GET, + String.format(Locale.ROOT, "%s%s", LTR_BASE_URI, "/stats/"), + RestRequest.Method.GET, + String.format(Locale.ROOT, "%s%s", LTR_LEGACY_BASE_URI, "/stats/") + ), + new ReplacedRoute( + RestRequest.Method.GET, + String.format(Locale.ROOT, "%s%s", LTR_BASE_URI, "/stats/{stat}"), + RestRequest.Method.GET, + String.format(Locale.ROOT, "%s%s", LTR_LEGACY_BASE_URI, "/stats/{stat}") + ) + ); + } + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + protected RestChannelConsumer prepareRequest(final RestRequest request, final NodeClient client) throws IOException { + final LTRStatsRequest ltrStatsRequest = getRequest(request); + return (channel) -> client.execute(LTRStatsAction.INSTANCE, + ltrStatsRequest, + new RestActions.NodesResponseRestListener(channel)); + } + + /** + * Creates a LTRStatsRequest from a RestRequest + * + * @param request RestRequest + * @return LTRStatsRequest + */ + private LTRStatsRequest getRequest(final RestRequest request) { + final LTRStatsRequest ltrStatsRequest = new LTRStatsRequest( + splitCommaSeparatedParam(request, "nodeId") + ); + ltrStatsRequest.timeout(request.param("timeout")); + + final List requestedStats = List.of(splitCommaSeparatedParam(request, "stat")); + + final Set validStats = ltrStats.getStats().keySet(); + + if (isAllStatsRequested(requestedStats)) { + ltrStatsRequest.addAll(validStats); + + } else { + ltrStatsRequest.addAll(getStatsToBeRetrieved(request, validStats, requestedStats)); + } + + return ltrStatsRequest; + } + + private Set getStatsToBeRetrieved( + final RestRequest request, + final Set validStats, + final List requestedStats) { + + if (requestedStats.contains(LTRStatsRequest.ALL_STATS_KEY)) { + throw new IllegalArgumentException(String.format("Request %s contains both %s and individual stats", + request.path(), LTRStatsRequest.ALL_STATS_KEY)); + } + + final Set invalidStats = requestedStats.stream() + .filter(s -> !validStats.contains(s)) + .collect(Collectors.toSet()); + + if (!invalidStats.isEmpty()) { + throw new IllegalArgumentException( + unrecognized(request, invalidStats, new HashSet<>(requestedStats), "stat")); + } + + return new HashSet<>(requestedStats); + } + + private boolean isAllStatsRequested(final List requestedStats) { + return requestedStats.isEmpty() + || (requestedStats.size() == 1 && requestedStats.contains(LTRStatsRequest.ALL_STATS_KEY)); + } + + private String[] splitCommaSeparatedParam(final RestRequest request, final String paramName) { + final String param = request.param(paramName); + + if (param == null) { + return new String[0]; + } else { + return param.split(","); + } + } +} diff --git a/src/main/java/org/opensearch/ltr/transport/LTRStatsAction.java b/src/main/java/org/opensearch/ltr/transport/LTRStatsAction.java new file mode 100644 index 0000000..1dc280a --- /dev/null +++ b/src/main/java/org/opensearch/ltr/transport/LTRStatsAction.java @@ -0,0 +1,38 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.opensearch.ltr.transport; + +import org.opensearch.action.ActionRequestBuilder; +import org.opensearch.action.ActionType; +import org.opensearch.client.OpenSearchClient; + +public class LTRStatsAction extends ActionType { + public static final String NAME = "cluster:admin/ltr/stats"; + public static final LTRStatsAction INSTANCE = new LTRStatsAction(); + + protected LTRStatsAction() { + super(NAME, LTRStatsNodesResponse::new); + } + + public static class LTRStatsRequestBuilder + extends ActionRequestBuilder { + private static final String[] nodeIds = null; + + protected LTRStatsRequestBuilder(final OpenSearchClient client) { + super(client, INSTANCE, new LTRStatsRequest(nodeIds)); + } + } +} diff --git a/src/main/java/org/opensearch/ltr/transport/LTRStatsNodeRequest.java b/src/main/java/org/opensearch/ltr/transport/LTRStatsNodeRequest.java new file mode 100644 index 0000000..f2b5cb7 --- /dev/null +++ b/src/main/java/org/opensearch/ltr/transport/LTRStatsNodeRequest.java @@ -0,0 +1,48 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.opensearch.ltr.transport; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.transport.TransportRequest; + +import java.io.IOException; + +/** + * LTRStatsNodeRequest to get a node stat + */ +public class LTRStatsNodeRequest extends TransportRequest { + private final LTRStatsRequest ltrStatsRequest; + + public LTRStatsNodeRequest(final LTRStatsRequest ltrStatsRequest) { + this.ltrStatsRequest = ltrStatsRequest; + } + + public LTRStatsNodeRequest(final StreamInput in) throws IOException { + super(in); + ltrStatsRequest = new LTRStatsRequest(in); + } + + public LTRStatsRequest getLTRStatsNodesRequest() { + return ltrStatsRequest; + } + + @Override + public void writeTo(final StreamOutput out) throws IOException { + super.writeTo(out); + ltrStatsRequest.writeTo(out); + } +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/ltr/transport/LTRStatsNodeResponse.java b/src/main/java/org/opensearch/ltr/transport/LTRStatsNodeResponse.java new file mode 100644 index 0000000..afea179 --- /dev/null +++ b/src/main/java/org/opensearch/ltr/transport/LTRStatsNodeResponse.java @@ -0,0 +1,63 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.opensearch.ltr.transport; + +import org.opensearch.action.support.nodes.BaseNodeResponse; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentFragment; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Map; + +public class LTRStatsNodeResponse extends BaseNodeResponse implements ToXContentFragment { + + private final Map statsMap; + + LTRStatsNodeResponse(final StreamInput in) throws IOException { + super(in); + this.statsMap = in.readMap(StreamInput::readString, StreamInput::readGenericValue); + } + + LTRStatsNodeResponse(final DiscoveryNode node, final Map statsToValues) { + super(node); + this.statsMap = statsToValues; + } + + public static LTRStatsNodeResponse readStats(final StreamInput in) throws IOException { + return new LTRStatsNodeResponse(in); + } + + public Map getStatsMap() { + return statsMap; + } + + @Override + public void writeTo(final StreamOutput out) throws IOException { + super.writeTo(out); + out.writeMap(statsMap, StreamOutput::writeString, StreamOutput::writeGenericValue); + } + + public XContentBuilder toXContent(final XContentBuilder builder, final Params params) throws IOException { + for (Map.Entry stat : statsMap.entrySet()) { + builder.field(stat.getKey(), stat.getValue()); + } + + return builder; + } +} diff --git a/src/main/java/org/opensearch/ltr/transport/LTRStatsNodesResponse.java b/src/main/java/org/opensearch/ltr/transport/LTRStatsNodesResponse.java new file mode 100644 index 0000000..db7c357 --- /dev/null +++ b/src/main/java/org/opensearch/ltr/transport/LTRStatsNodesResponse.java @@ -0,0 +1,82 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.opensearch.ltr.transport; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.nodes.BaseNodesResponse; +import org.opensearch.cluster.ClusterName; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +public class LTRStatsNodesResponse extends BaseNodesResponse implements ToXContent { + private static final String NODES_KEY = "nodes"; + private final Map clusterStats; + + public LTRStatsNodesResponse(final StreamInput in) throws IOException { + super(new ClusterName(in), in.readList(LTRStatsNodeResponse::readStats), in.readList(FailedNodeException::new)); + clusterStats = in.readMap(); + } + + public LTRStatsNodesResponse( + final ClusterName clusterName, + final List nodeResponses, + final List failures, Map clusterStats) { + super(clusterName, nodeResponses, failures); + this.clusterStats = clusterStats; + } + + Map getClusterStats() { + return clusterStats; + } + + @Override + protected List readNodesFrom(final StreamInput in) throws IOException { + return in.readList(LTRStatsNodeResponse::readStats); + } + + @Override + protected void writeNodesTo(final StreamOutput out, final List nodeResponses) throws IOException { + out.writeList(nodeResponses); + } + + @Override + public void writeTo(final StreamOutput out) throws IOException { + super.writeTo(out); + out.writeMap(clusterStats); + } + + @Override + public XContentBuilder toXContent(final XContentBuilder builder, final Params params) throws IOException { + for (final Map.Entry clusterStat : clusterStats.entrySet()) { + builder.field(clusterStat.getKey(), clusterStat.getValue()); + } + + builder.startObject(NODES_KEY); + for (final LTRStatsNodeResponse ltrStats : getNodes()) { + builder.startObject(ltrStats.getNode().getId()); + ltrStats.toXContent(builder, params); + builder.endObject(); + } + builder.endObject(); + return builder; + } +} diff --git a/src/main/java/org/opensearch/ltr/transport/LTRStatsRequest.java b/src/main/java/org/opensearch/ltr/transport/LTRStatsRequest.java new file mode 100644 index 0000000..f618a29 --- /dev/null +++ b/src/main/java/org/opensearch/ltr/transport/LTRStatsRequest.java @@ -0,0 +1,90 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.opensearch.ltr.transport; + + +import org.opensearch.action.support.nodes.BaseNodesRequest; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import java.io.IOException; +import java.util.HashSet; +import java.util.Set; + +public class LTRStatsRequest extends BaseNodesRequest { + + /** + * Key indicating all stats should be retrieved + */ + public static final String ALL_STATS_KEY = "_all"; + + private final Set statsToBeRetrieved; + + public LTRStatsRequest(final StreamInput in) throws IOException { + super(in); + statsToBeRetrieved = in.readSet(StreamInput::readString); + } + + /** + * Constructor + * + * @param nodeIds nodeIds of nodes' stats to be retrieved + */ + public LTRStatsRequest(final String... nodeIds) { + super(nodeIds); + statsToBeRetrieved = new HashSet<>(); + } + + /** + * Adds a stat to the set of stats to be retrieved + * + * @param stat name of the stat + */ + public void addStat(final String stat) { + statsToBeRetrieved.add(stat); + } + + /** + * Add all stats to be retrieved + * + * @param statsToBeAdded set of stats to be retrieved + */ + public void addAll(final Set statsToBeAdded) { + statsToBeRetrieved.addAll(statsToBeAdded); + } + + /** + * Remove all stats from retrieval set + */ + public void clear() { + statsToBeRetrieved.clear(); + } + + /** + * Get the set that tracks which stats should be retrieved + * + * @return the set that contains the stat names marked for retrieval + */ + public Set getStatsToBeRetrieved() { + return statsToBeRetrieved; + } + + @Override + public void writeTo(final StreamOutput out) throws IOException { + super.writeTo(out); + out.writeStringCollection(statsToBeRetrieved); + } +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/ltr/transport/TransportLTRStatsAction.java b/src/main/java/org/opensearch/ltr/transport/TransportLTRStatsAction.java new file mode 100644 index 0000000..7303c5c --- /dev/null +++ b/src/main/java/org/opensearch/ltr/transport/TransportLTRStatsAction.java @@ -0,0 +1,95 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.opensearch.ltr.transport; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.nodes.TransportNodesAction; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.ltr.stats.LTRStats; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +public class TransportLTRStatsAction extends + TransportNodesAction { + + private final LTRStats ltrStats; + + @Inject + public TransportLTRStatsAction( + final ThreadPool threadPool, + final ClusterService clusterService, + final TransportService transportService, + final ActionFilters actionFilters, + final LTRStats ltrStats) { + + super(LTRStatsAction.NAME, threadPool, clusterService, transportService, + actionFilters, LTRStatsRequest::new, LTRStatsNodeRequest::new, + ThreadPool.Names.MANAGEMENT, LTRStatsNodeResponse.class); + this.ltrStats = ltrStats; + } + + @Override + protected LTRStatsNodesResponse newResponse( + final LTRStatsRequest request, + final List nodeResponses, + final List failures) { + + final Set statsToBeRetrieved = request.getStatsToBeRetrieved(); + final Map clusterStats = ltrStats.getClusterStats() + .entrySet() + .stream() + .filter(e -> statsToBeRetrieved.contains(e.getKey())) + .collect( + Collectors.toMap(Map.Entry::getKey, e -> e.getValue().getValue()) + ); + + return new LTRStatsNodesResponse(clusterService.getClusterName(), nodeResponses, failures, clusterStats); + } + + @Override + protected LTRStatsNodeRequest newNodeRequest(final LTRStatsRequest request) { + return new LTRStatsNodeRequest(request); + } + + @Override + protected LTRStatsNodeResponse newNodeResponse(final StreamInput in) throws IOException { + return new LTRStatsNodeResponse(in); + } + + @Override + protected LTRStatsNodeResponse nodeOperation(final LTRStatsNodeRequest request) { + final LTRStatsRequest ltrStatsRequest = request.getLTRStatsNodesRequest(); + final Set statsToBeRetrieved = ltrStatsRequest.getStatsToBeRetrieved(); + + final Map statValues = ltrStats.getNodeStats() + .entrySet() + .stream() + .filter(e -> statsToBeRetrieved.contains(e.getKey())) + .collect( + Collectors.toMap(Map.Entry::getKey, e -> e.getValue().getValue()) + ); + return new LTRStatsNodeResponse(clusterService.localNode(), statValues); + } +} diff --git a/src/test/java/org/opensearch/ltr/transport/TransportLTRStatsActionTests.java b/src/test/java/org/opensearch/ltr/transport/TransportLTRStatsActionTests.java new file mode 100644 index 0000000..9001f22 --- /dev/null +++ b/src/test/java/org/opensearch/ltr/transport/TransportLTRStatsActionTests.java @@ -0,0 +1,91 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.opensearch.ltr.transport; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.ltr.stats.LTRStat; +import org.opensearch.ltr.stats.LTRStats; +import org.opensearch.ltr.stats.StatName; +import org.opensearch.ltr.stats.suppliers.CounterSupplier; +import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.transport.TransportService; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.mockito.Mockito.mock; + +public class TransportLTRStatsActionTests extends OpenSearchIntegTestCase { + + private TransportLTRStatsAction action; + private LTRStats ltrStats; + private Map> statsMap; + private StatName clusterStatName; + private StatName nodeStatName; + + @Before + public void setUp() throws Exception { + super.setUp(); + + clusterStatName = StatName.LTR_PLUGIN_STATUS; + nodeStatName = StatName.LTR_CACHE_STATS; + + statsMap = new HashMap<>(); + statsMap.put(clusterStatName.getName(), new LTRStat<>(false, new CounterSupplier())); + statsMap.put(nodeStatName.getName(), new LTRStat<>(true, () -> "test")); + + ltrStats = new LTRStats(statsMap); + + action = new TransportLTRStatsAction( + client().threadPool(), + clusterService(), + mock(TransportService.class), + mock(ActionFilters.class), + ltrStats + ); + } + + @Test + public void testNewResponse() { + String[] nodeIds = null; + LTRStatsRequest ltrStatsRequest = new LTRStatsRequest(nodeIds); + ltrStatsRequest.addAll(ltrStats.getStats().keySet()); + + List responses = new ArrayList<>(); + List failures = new ArrayList<>(); + + LTRStatsNodesResponse ltrStatsResponse = action.newResponse(ltrStatsRequest, responses, failures); + assertEquals(1, ltrStatsResponse.getClusterStats().size()); + } + + @Test + public void testNodeOperation() { + String[] nodeIds = null; + LTRStatsRequest ltrStatsRequest = new LTRStatsRequest(nodeIds); + ltrStatsRequest.addAll(ltrStats.getStats().keySet()); + + LTRStatsNodeResponse response = action.nodeOperation(new LTRStatsNodeRequest(ltrStatsRequest)); + + Map stats = response.getStatsMap(); + + assertEquals(1, stats.size()); + } +}