Skip to content

Commit

Permalink
Change input to be string instead of ADTaskProfileRequest
Browse files Browse the repository at this point in the history
Signed-off-by: Tyler Ohlsen <[email protected]>
  • Loading branch information
ohltyler committed Dec 21, 2023
1 parent 7814079 commit cbe6c99
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.ad.transport.ADTaskProfileRequest;
import org.opensearch.ad.transport.ADTaskProfileResponse;
import org.opensearch.common.action.ActionFuture;
import org.opensearch.core.action.ActionListener;
Expand Down Expand Up @@ -55,20 +54,20 @@ default ActionFuture<SearchResponse> searchAnomalyResults(SearchRequest searchRe

/**
* Get detector profile - refer to https://opensearch.org/docs/latest/observing-your-data/ad/api/#profile-detector
* @param profileRequest profile request to fetch detector profile
* @param detectorId the detector ID to fetch the profile for
* @return ActionFuture of ADTaskProfileResponse
*/
default ActionFuture<ADTaskProfileResponse> getDetectorProfile(ADTaskProfileRequest profileRequest) {
default ActionFuture<ADTaskProfileResponse> getDetectorProfile(String detectorId) {
PlainActionFuture<ADTaskProfileResponse> actionFuture = PlainActionFuture.newFuture();
getDetectorProfile(profileRequest, actionFuture);
getDetectorProfile(detectorId, actionFuture);
return actionFuture;
}

/**
* Get detector profile - refer to https://opensearch.org/docs/latest/observing-your-data/ad/api/#profile-detector
* @param profileRequest profile request to fetch detector profile
* @param detectorId the detector ID to fetch the profile for
* @param listener a listener to be notified of the result
*/
void getDetectorProfile(ADTaskProfileRequest profileRequest, ActionListener<ADTaskProfileResponse> listener);
void getDetectorProfile(String detectorId, ActionListener<ADTaskProfileResponse> listener);

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,35 @@

package org.opensearch.ad.client;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;
import java.util.function.Predicate;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.ad.constant.ADCommonName;
import org.opensearch.ad.transport.ADTaskProfileAction;
import org.opensearch.ad.transport.ADTaskProfileRequest;
import org.opensearch.ad.transport.ADTaskProfileResponse;
import org.opensearch.ad.transport.SearchAnomalyDetectorAction;
import org.opensearch.ad.transport.SearchAnomalyResultAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.ActionResponse;

public class AnomalyDetectionNodeClient implements AnomalyDetectionClient {
private final Client client;
private final ClusterService clusterService;
private final HotDataNodePredicate eligibleNodeFilter;

public AnomalyDetectionNodeClient(Client client) {
public AnomalyDetectionNodeClient(Client client, ClusterService clusterService) {
this.client = client;
this.clusterService = clusterService;
this.eligibleNodeFilter = new HotDataNodePredicate();
}

@Override
Expand All @@ -37,9 +51,56 @@ public void searchAnomalyResults(SearchRequest searchRequest, ActionListener<Sea
}

@Override
public void getDetectorProfile(ADTaskProfileRequest profileRequest, ActionListener<ADTaskProfileResponse> listener) {
this.client.execute(ADTaskProfileAction.INSTANCE, profileRequest, ActionListener.wrap(profileResponse -> {
listener.onResponse(profileResponse);
}, listener::onFailure));
public void getDetectorProfile(String detectorId, ActionListener<ADTaskProfileResponse> listener) {

// TODO: clean up
// Logic to determine eligible nodes comes from org.opensearch.timeseries.util.DiscoveryNodeFilterer
// There is no clean way to consume that within this client's constructor since it will be instantiated
// outside of this plugin typically. So we re-use that logic here
ClusterState state = this.clusterService.state();
final List<DiscoveryNode> eligibleNodes = new ArrayList<>();
for (DiscoveryNode node : state.nodes()) {
if (this.eligibleNodeFilter.test(node)) {
eligibleNodes.add(node);
}
}
final DiscoveryNode[] eligibleNodesAsArray = eligibleNodes.toArray(new DiscoveryNode[0]);

ADTaskProfileRequest profileRequest = new ADTaskProfileRequest(detectorId, eligibleNodesAsArray);
this.client.execute(ADTaskProfileAction.INSTANCE, profileRequest, getADTaskProfileResponseActionListener(listener));
}

// We need to wrap AD-specific response type listeners around an internal listener, and re-generate the response from a generic
// ActionResponse. This is needed to prevent classloader issues and ClassCastExceptions when executed by other plugins.
private ActionListener<ADTaskProfileResponse> getADTaskProfileResponseActionListener(ActionListener<ADTaskProfileResponse> listener) {
ActionListener<ADTaskProfileResponse> internalListener = ActionListener
.wrap(profileResponse -> { listener.onResponse(profileResponse); }, listener::onFailure);
ActionListener<ADTaskProfileResponse> actionListener = wrapActionListener(internalListener, actionResponse -> {
ADTaskProfileResponse response = ADTaskProfileResponse.fromActionResponse(actionResponse);
return response;
});
return actionListener;
}

private <T extends ActionResponse> ActionListener<T> wrapActionListener(
final ActionListener<T> listener,
final Function<ActionResponse, T> recreate
) {
ActionListener<T> actionListener = ActionListener.wrap(r -> {
listener.onResponse(recreate.apply(r));
;
}, e -> { listener.onFailure(e); });
return actionListener;
}

static class HotDataNodePredicate implements Predicate<DiscoveryNode> {
@Override
public boolean test(DiscoveryNode discoveryNode) {
return discoveryNode.isDataNode()
&& discoveryNode
.getAttributes()
.getOrDefault(ADCommonName.BOX_TYPE_KEY, ADCommonName.HOT_BOX_TYPE)
.equals(ADCommonName.HOT_BOX_TYPE);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,18 @@

package org.opensearch.ad.transport;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.List;

import org.opensearch.action.FailedNodeException;
import org.opensearch.action.support.nodes.BaseNodesResponse;
import org.opensearch.cluster.ClusterName;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;

Expand All @@ -40,4 +46,18 @@ public List<ADTaskProfileNodeResponse> readNodesFrom(StreamInput in) throws IOEx
return in.readList(ADTaskProfileNodeResponse::readNodeResponse);
}

public static ADTaskProfileResponse fromActionResponse(ActionResponse actionResponse) {
if (actionResponse instanceof ADTaskProfileResponse) {
return (ADTaskProfileResponse) actionResponse;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new ADTaskProfileResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into MLModelGetResponse", e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import org.mockito.MockitoAnnotations;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.ad.transport.ADTaskProfileRequest;
import org.opensearch.ad.transport.ADTaskProfileResponse;
import org.opensearch.core.action.ActionListener;

Expand Down Expand Up @@ -47,7 +46,7 @@ public void searchAnomalyResults(SearchRequest searchRequest, ActionListener<Sea
}

@Override
public void getDetectorProfile(ADTaskProfileRequest profileRequest, ActionListener<ADTaskProfileResponse> listener) {
public void getDetectorProfile(String detectorId, ActionListener<ADTaskProfileResponse> listener) {
listener.onResponse(profileResponse);
}
};
Expand All @@ -65,7 +64,7 @@ public void searchAnomalyResults() {

@Test
public void getDetectorProfile() {
assertEquals(profileResponse, anomalyDetectionClient.getDetectorProfile(new ADTaskProfileRequest("foo", null)).actionGet());
assertEquals(profileResponse, anomalyDetectionClient.getDetectorProfile("foo").actionGet());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
import org.opensearch.ad.model.AnomalyDetectorType;
import org.opensearch.ad.transport.ADTaskProfileAction;
import org.opensearch.ad.transport.ADTaskProfileNodeResponse;
import org.opensearch.ad.transport.ADTaskProfileRequest;
import org.opensearch.ad.transport.ADTaskProfileResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.ClusterName;
Expand Down Expand Up @@ -67,7 +66,7 @@ public class AnomalyDetectionNodeClientTests extends HistoricalAnalysisIntegTest
@Before
public void setup() {
clientSpy = spy(client());
adClient = new AnomalyDetectionNodeClient(clientSpy);
adClient = new AnomalyDetectionNodeClient(clientSpy, clusterService());
}

@Test
Expand Down Expand Up @@ -150,17 +149,15 @@ public void testGetDetectorProfile_NoIndices() throws ExecutionException, Interr
deleteIndexIfExists(CommonName.CONFIG_INDEX);
deleteIndexIfExists(ALL_AD_RESULTS_INDEX_PATTERN);
deleteIndexIfExists(ADCommonName.DETECTION_STATE_INDEX);
DiscoveryNode localNode = clusterService().localNode();

profileFuture = mock(PlainActionFuture.class);
ADTaskProfileRequest profileRequest = new ADTaskProfileRequest("foo", localNode);
ADTaskProfileResponse response = adClient.getDetectorProfile(profileRequest).actionGet(10000);
ADTaskProfileResponse response = adClient.getDetectorProfile("foo").actionGet(10000);
List<ADTaskProfileNodeResponse> responses = response.getNodes();

verify(clientSpy, times(1)).execute(any(ADTaskProfileAction.class), any(), any());
// We should get node response back from the local node, but there should be no profiles found
assertEquals(1, responses.size());
assertNotEquals(0, responses.size());
assertEquals(null, responses.get(0).getAdTaskProfile());
verify(clientSpy, times(1)).execute(any(ADTaskProfileAction.class), any(), any());

}

@Test
Expand All @@ -181,11 +178,11 @@ public void testGetDetectorProfile_Populated() {
return null;
}).when(clientSpy).execute(any(ADTaskProfileAction.class), any(), any());

ADTaskProfileRequest profileRequest = new ADTaskProfileRequest("foo", localNode);
ADTaskProfileResponse response = adClient.getDetectorProfile(profileRequest).actionGet(10000);
ADTaskProfileResponse response = adClient.getDetectorProfile("foo").actionGet(10000);
String responseTaskId = response.getNodes().get(0).getAdTaskProfile().getTaskId();

verify(clientSpy, times(1)).execute(any(ADTaskProfileAction.class), any(), any());
assertNotEquals(0, response.getNodes().size());
assertEquals(responseTaskId, adTaskProfile.getTaskId());
}

Expand Down

0 comments on commit cbe6c99

Please sign in to comment.