diff --git a/src/main/java/org/opensearch/ad/client/AnomalyDetectionClient.java b/src/main/java/org/opensearch/ad/client/AnomalyDetectionClient.java index c4fdb6644..08ae805ad 100644 --- a/src/main/java/org/opensearch/ad/client/AnomalyDetectionClient.java +++ b/src/main/java/org/opensearch/ad/client/AnomalyDetectionClient.java @@ -8,6 +8,7 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.ad.transport.ADTaskProfileResponse; import org.opensearch.common.action.ActionFuture; import org.opensearch.core.action.ActionListener; @@ -40,7 +41,7 @@ default ActionFuture searchAnomalyDetectors(SearchRequest search */ default ActionFuture searchAnomalyResults(SearchRequest searchRequest) { PlainActionFuture actionFuture = PlainActionFuture.newFuture(); - searchAnomalyDetectors(searchRequest, actionFuture); + searchAnomalyResults(searchRequest, actionFuture); return actionFuture; } @@ -51,4 +52,22 @@ default ActionFuture searchAnomalyResults(SearchRequest searchRe */ void searchAnomalyResults(SearchRequest searchRequest, ActionListener listener); + /** + * Get detector profile - refer to https://opensearch.org/docs/latest/observing-your-data/ad/api/#profile-detector + * @param detectorId the detector ID to fetch the profile for + * @return ActionFuture of ADTaskProfileResponse + */ + default ActionFuture getDetectorProfile(String detectorId) { + PlainActionFuture actionFuture = PlainActionFuture.newFuture(); + getDetectorProfile(detectorId, actionFuture); + return actionFuture; + } + + /** + * Get detector profile - refer to https://opensearch.org/docs/latest/observing-your-data/ad/api/#profile-detector + * @param detectorId the detector ID to fetch the profile for + * @param listener a listener to be notified of the result + */ + void getDetectorProfile(String detectorId, ActionListener listener); + } diff --git a/src/main/java/org/opensearch/ad/client/AnomalyDetectionNodeClient.java b/src/main/java/org/opensearch/ad/client/AnomalyDetectionNodeClient.java index a9d690ca1..8deddc00a 100644 --- a/src/main/java/org/opensearch/ad/client/AnomalyDetectionNodeClient.java +++ b/src/main/java/org/opensearch/ad/client/AnomalyDetectionNodeClient.java @@ -5,18 +5,29 @@ package org.opensearch.ad.client; +import java.util.function.Function; + import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.transport.ADTaskProfileAction; +import org.opensearch.ad.transport.ADTaskProfileRequest; +import org.opensearch.ad.transport.ADTaskProfileResponse; import org.opensearch.ad.transport.SearchAnomalyDetectorAction; import org.opensearch.ad.transport.SearchAnomalyResultAction; import org.opensearch.client.Client; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; public class AnomalyDetectionNodeClient implements AnomalyDetectionClient { private final Client client; + private final DiscoveryNodeFilterer nodeFilterer; - public AnomalyDetectionNodeClient(Client client) { + public AnomalyDetectionNodeClient(Client client, ClusterService clusterService) { this.client = client; + this.nodeFilterer = new DiscoveryNodeFilterer(clusterService); } @Override @@ -32,4 +43,34 @@ public void searchAnomalyResults(SearchRequest searchRequest, ActionListener listener) { + final DiscoveryNode[] eligibleNodes = this.nodeFilterer.getEligibleDataNodes(); + ADTaskProfileRequest profileRequest = new ADTaskProfileRequest(detectorId, eligibleNodes); + this.client.execute(ADTaskProfileAction.INSTANCE, profileRequest, getADTaskProfileResponseActionListener(listener)); + } + + // We need to wrap AD-specific response type listeners around an internal listener, and re-generate the response from a generic + // ActionResponse. This is needed to prevent classloader issues and ClassCastExceptions when executed by other plugins. + private ActionListener getADTaskProfileResponseActionListener(ActionListener listener) { + ActionListener internalListener = ActionListener + .wrap(profileResponse -> { listener.onResponse(profileResponse); }, listener::onFailure); + ActionListener actionListener = wrapActionListener(internalListener, actionResponse -> { + ADTaskProfileResponse response = ADTaskProfileResponse.fromActionResponse(actionResponse); + return response; + }); + return actionListener; + } + + private ActionListener wrapActionListener( + final ActionListener listener, + final Function recreate + ) { + ActionListener actionListener = ActionListener.wrap(r -> { + listener.onResponse(recreate.apply(r)); + ; + }, e -> { listener.onFailure(e); }); + return actionListener; + } } diff --git a/src/main/java/org/opensearch/ad/transport/ADTaskProfileResponse.java b/src/main/java/org/opensearch/ad/transport/ADTaskProfileResponse.java index 951f362f9..9886ca392 100644 --- a/src/main/java/org/opensearch/ad/transport/ADTaskProfileResponse.java +++ b/src/main/java/org/opensearch/ad/transport/ADTaskProfileResponse.java @@ -11,12 +11,18 @@ package org.opensearch.ad.transport; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.UncheckedIOException; import java.util.List; import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.nodes.BaseNodesResponse; import org.opensearch.cluster.ClusterName; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -40,4 +46,18 @@ public List readNodesFrom(StreamInput in) throws IOEx return in.readList(ADTaskProfileNodeResponse::readNodeResponse); } + public static ADTaskProfileResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof ADTaskProfileResponse) { + return (ADTaskProfileResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new ADTaskProfileResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into ADTaskProfileResponse", e); + } + } } diff --git a/src/test/java/org/opensearch/ad/client/AnomalyDetectionClientTests.java b/src/test/java/org/opensearch/ad/client/AnomalyDetectionClientTests.java index 2fe3e976f..95bfe24d6 100644 --- a/src/test/java/org/opensearch/ad/client/AnomalyDetectionClientTests.java +++ b/src/test/java/org/opensearch/ad/client/AnomalyDetectionClientTests.java @@ -13,6 +13,7 @@ import org.mockito.MockitoAnnotations; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.transport.ADTaskProfileResponse; import org.opensearch.core.action.ActionListener; public class AnomalyDetectionClientTests { @@ -20,7 +21,13 @@ public class AnomalyDetectionClientTests { AnomalyDetectionClient anomalyDetectionClient; @Mock - SearchResponse searchResponse; + SearchResponse searchDetectorsResponse; + + @Mock + SearchResponse searchResultsResponse; + + @Mock + ADTaskProfileResponse profileResponse; @Before public void setUp() { @@ -30,24 +37,34 @@ public void setUp() { anomalyDetectionClient = new AnomalyDetectionClient() { @Override public void searchAnomalyDetectors(SearchRequest searchRequest, ActionListener listener) { - listener.onResponse(searchResponse); + listener.onResponse(searchDetectorsResponse); } @Override public void searchAnomalyResults(SearchRequest searchRequest, ActionListener listener) { - listener.onResponse(searchResponse); + listener.onResponse(searchResultsResponse); + } + + @Override + public void getDetectorProfile(String detectorId, ActionListener listener) { + listener.onResponse(profileResponse); } }; } @Test public void searchAnomalyDetectors() { - assertEquals(searchResponse, anomalyDetectionClient.searchAnomalyDetectors(new SearchRequest()).actionGet()); + assertEquals(searchDetectorsResponse, anomalyDetectionClient.searchAnomalyDetectors(new SearchRequest()).actionGet()); } @Test public void searchAnomalyResults() { - assertEquals(searchResponse, anomalyDetectionClient.searchAnomalyResults(new SearchRequest()).actionGet()); + assertEquals(searchResultsResponse, anomalyDetectionClient.searchAnomalyResults(new SearchRequest()).actionGet()); + } + + @Test + public void getDetectorProfile() { + assertEquals(profileResponse, anomalyDetectionClient.getDetectorProfile("foo").actionGet()); } } diff --git a/src/test/java/org/opensearch/ad/client/AnomalyDetectionNodeClientTests.java b/src/test/java/org/opensearch/ad/client/AnomalyDetectionNodeClientTests.java index e8a64c92a..6f5eaa37d 100644 --- a/src/test/java/org/opensearch/ad/client/AnomalyDetectionNodeClientTests.java +++ b/src/test/java/org/opensearch/ad/client/AnomalyDetectionNodeClientTests.java @@ -5,18 +5,25 @@ package org.opensearch.ad.client; -import static org.junit.Assert.assertEquals; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.opensearch.ad.indices.ADIndexManagement.ALL_AD_RESULTS_INDEX_PATTERN; import static org.opensearch.ad.model.AnomalyDetector.DETECTOR_TYPE_FIELD; -import static org.opensearch.timeseries.TestHelpers.matchAllRequest; import java.io.IOException; import java.time.Instant; import java.time.temporal.ChronoUnit; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ExecutionException; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.junit.Before; import org.junit.Test; import org.opensearch.action.search.SearchRequest; @@ -24,12 +31,21 @@ import org.opensearch.action.support.PlainActionFuture; import org.opensearch.ad.HistoricalAnalysisIntegTestCase; import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.model.ADTaskProfile; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.AnomalyDetectorType; +import org.opensearch.ad.transport.ADTaskProfileAction; +import org.opensearch.ad.transport.ADTaskProfileNodeResponse; +import org.opensearch.ad.transport.ADTaskProfileResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.action.ActionListener; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.constant.CommonName; import com.google.common.collect.ImmutableList; @@ -38,22 +54,26 @@ // The exhaustive set of transport action scenarios are within the respective transport action // test suites themselves. We do not want to unnecessarily duplicate all of those tests here. public class AnomalyDetectionNodeClientTests extends HistoricalAnalysisIntegTestCase { + private final Logger logger = LogManager.getLogger(this.getClass()); private String indexName = "test-data"; private Instant startTime = Instant.now().minus(2, ChronoUnit.DAYS); + private Client clientSpy; private AnomalyDetectionNodeClient adClient; - private PlainActionFuture future; + private PlainActionFuture searchResponseFuture; + private PlainActionFuture profileFuture; @Before public void setup() { - adClient = new AnomalyDetectionNodeClient(client()); + clientSpy = spy(client()); + adClient = new AnomalyDetectionNodeClient(clientSpy, clusterService()); } @Test public void testSearchAnomalyDetectors_NoIndices() { deleteIndexIfExists(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS); - SearchResponse searchResponse = adClient.searchAnomalyDetectors(matchAllRequest()).actionGet(10000); + SearchResponse searchResponse = adClient.searchAnomalyDetectors(TestHelpers.matchAllRequest()).actionGet(10000); assertEquals(0, searchResponse.getInternalResponse().hits().getTotalHits().value); } @@ -62,13 +82,13 @@ public void testSearchAnomalyDetectors_Empty() throws IOException { deleteIndexIfExists(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS); createDetectorIndex(); - SearchResponse searchResponse = adClient.searchAnomalyDetectors(matchAllRequest()).actionGet(10000); + SearchResponse searchResponse = adClient.searchAnomalyDetectors(TestHelpers.matchAllRequest()).actionGet(10000); assertEquals(0, searchResponse.getInternalResponse().hits().getTotalHits().value); } @Test public void searchAnomalyDetectors_Populated() throws IOException { - ingestTestData(indexName, startTime, 1, "test", 3000); + ingestTestData(indexName, startTime, 1, "test", 10); String detectorType = AnomalyDetectorType.SINGLE_ENTITY.name(); AnomalyDetector detector = TestHelpers .randomAnomalyDetector( @@ -94,18 +114,18 @@ public void searchAnomalyDetectors_Populated() throws IOException { @Test public void testSearchAnomalyResults_NoIndices() { - future = mock(PlainActionFuture.class); + searchResponseFuture = mock(PlainActionFuture.class); SearchRequest request = new SearchRequest().indices(new String[] {}); - adClient.searchAnomalyResults(request, future); - verify(future).onFailure(any(IllegalArgumentException.class)); + adClient.searchAnomalyResults(request, searchResponseFuture); + verify(searchResponseFuture).onFailure(any(IllegalArgumentException.class)); } @Test public void testSearchAnomalyResults_Empty() throws IOException { createADResultIndex(); SearchResponse searchResponse = adClient - .searchAnomalyResults(matchAllRequest().indices(ALL_AD_RESULTS_INDEX_PATTERN)) + .searchAnomalyResults(TestHelpers.matchAllRequest().indices(ALL_AD_RESULTS_INDEX_PATTERN)) .actionGet(10000); assertEquals(0, searchResponse.getInternalResponse().hits().getTotalHits().value); } @@ -117,11 +137,52 @@ public void testSearchAnomalyResults_Populated() throws IOException { String adResultId = createADResult(TestHelpers.randomAnomalyDetectResult()); SearchResponse searchResponse = adClient - .searchAnomalyResults(matchAllRequest().indices(ALL_AD_RESULTS_INDEX_PATTERN)) + .searchAnomalyResults(TestHelpers.matchAllRequest().indices(ALL_AD_RESULTS_INDEX_PATTERN)) .actionGet(10000); - assertEquals(1, searchResponse.getInternalResponse().hits().getTotalHits().value); + assertEquals(1, searchResponse.getInternalResponse().hits().getTotalHits().value); assertEquals(adResultId, searchResponse.getInternalResponse().hits().getAt(0).getId()); } + @Test + public void testGetDetectorProfile_NoIndices() throws ExecutionException, InterruptedException { + deleteIndexIfExists(CommonName.CONFIG_INDEX); + deleteIndexIfExists(ALL_AD_RESULTS_INDEX_PATTERN); + deleteIndexIfExists(ADCommonName.DETECTION_STATE_INDEX); + + profileFuture = mock(PlainActionFuture.class); + ADTaskProfileResponse response = adClient.getDetectorProfile("foo").actionGet(10000); + List responses = response.getNodes(); + + assertNotEquals(0, responses.size()); + assertEquals(null, responses.get(0).getAdTaskProfile()); + verify(clientSpy, times(1)).execute(any(ADTaskProfileAction.class), any(), any()); + + } + + @Test + public void testGetDetectorProfile_Populated() { + DiscoveryNode localNode = clusterService().localNode(); + ADTaskProfile adTaskProfile = new ADTaskProfile("foo-task-id", 0, 0L, false, 0, 0L, localNode.getId()); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + + ActionListener listener = (ActionListener) args[2]; + ADTaskProfileNodeResponse nodeResponse = new ADTaskProfileNodeResponse(localNode, adTaskProfile, null); + + List nodeResponses = Arrays.asList(nodeResponse); + listener.onResponse(new ADTaskProfileResponse(new ClusterName("test-cluster"), nodeResponses, Collections.emptyList())); + + return null; + }).when(clientSpy).execute(any(ADTaskProfileAction.class), any(), any()); + + ADTaskProfileResponse response = adClient.getDetectorProfile("foo").actionGet(10000); + String responseTaskId = response.getNodes().get(0).getAdTaskProfile().getTaskId(); + + assertNotEquals(0, response.getNodes().size()); + assertEquals(responseTaskId, adTaskProfile.getTaskId()); + verify(clientSpy, times(1)).execute(any(ADTaskProfileAction.class), any(), any()); + } + } diff --git a/src/test/java/org/opensearch/ad/transport/ADTaskProfileResponseTests.java b/src/test/java/org/opensearch/ad/transport/ADTaskProfileResponseTests.java index 2f6d555d1..2654113d7 100644 --- a/src/test/java/org/opensearch/ad/transport/ADTaskProfileResponseTests.java +++ b/src/test/java/org/opensearch/ad/transport/ADTaskProfileResponseTests.java @@ -21,7 +21,9 @@ import org.opensearch.ad.model.ADTaskProfile; import org.opensearch.cluster.ClusterName; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import com.google.common.collect.ImmutableList; @@ -53,4 +55,41 @@ public void testSerializeResponse() throws IOException { assertEquals(1, response2.getNodes().size()); assertEquals(taskId, response2.getNodes().get(0).getAdTaskProfile().getTaskId()); } + + public void testFromActionResponse() throws IOException { + String taskId = randomAlphaOfLength(5); + ADTaskProfile adTaskProfile = new ADTaskProfile(); + adTaskProfile.setTaskId(taskId); + Version remoteAdVersion = Version.CURRENT; + ADTaskProfileNodeResponse nodeResponse = new ADTaskProfileNodeResponse(randomDiscoveryNode(), adTaskProfile, remoteAdVersion); + + List nodeResponses = ImmutableList.of(nodeResponse); + ADTaskProfileResponse response = new ADTaskProfileResponse(new ClusterName("test"), nodeResponses, ImmutableList.of()); + + ADTaskProfileResponse reserializedResponse = ADTaskProfileResponse.fromActionResponse((ActionResponse) response); + assertEquals(1, reserializedResponse.getNodes().size()); + assertEquals(taskId, reserializedResponse.getNodes().get(0).getAdTaskProfile().getTaskId()); + + BytesStreamOutput output = new BytesStreamOutput(); + response.writeNodesTo(output, nodeResponses); + StreamInput input = output.bytes().streamInput(); + + ActionResponse invalidActionResponse = new TestActionResponse(input); + assertThrows(Exception.class, () -> ADTaskProfileResponse.fromActionResponse(invalidActionResponse)); + + } + + // A test ActionResponse class with an inactive writeTo class. Used to ensure exceptions + // are thrown when parsing implementations of such class. + private class TestActionResponse extends ActionResponse { + public TestActionResponse(StreamInput in) throws IOException { + super(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + return; + } + } + }